Recently, I was working on a deep learning project, and my model was taking an excessively long time to converge. The training process was frustratingly slow, and the accuracy wasn’t improving as I had hoped. That’s when I decided to implement Batch Normalization, a technique that significantly enhanced my model’s performance and reduced training time.
In this article, I will walk you through everything you need to know about Batch Normalization in PyTorch, from the basic concept to implementation details and best practices. If you’re struggling with slow convergence or unstable training, this guide is for you.
So let’s get started!
Batch Normalization
Batch Normalization (BatchNorm) is a technique introduced in 2015 that normalizes the inputs of each layer, making networks more stable and allowing faster training with higher learning rates.
Think of BatchNorm as a standardization process that happens inside your neural network. It takes the outputs from a layer, normalizes them (mean of 0, variance of 1), and then applies learnable parameters to scale and shift the normalized values.
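To make this concrete, here is a minimal sketch that reproduces what a BatchNorm layer computes in training mode: normalize with the batch’s (biased) mean and variance, then apply the learnable scale and shift. The scale (bn.weight) starts at 1 and the shift (bn.bias) at 0, so a freshly created layer initially just standardizes its input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4)  # batch of 8 samples, 4 features

bn = nn.BatchNorm1d(4)
bn.train()
out = bn(x)

# Manual computation: per-feature mean and biased variance over the batch
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
manual = (x - mean) / torch.sqrt(var + bn.eps)
# The affine step: bn.weight starts at 1 and bn.bias at 0,
# so this is initially a no-op
manual = manual * bn.weight + bn.bias

print(torch.allclose(out, manual, atol=1e-6))  # True
```

Once training begins, the scale and shift are updated by gradient descent like any other parameters, which lets the network undo the normalization where that helps.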
The key benefits of BatchNorm include:
- Faster training (models often converge in far fewer steps)
- Higher learning rates without divergence
- Less sensitivity to initialization
- Some regularization effect
Implement Batch Normalization in PyTorch
PyTorch makes it incredibly easy to add BatchNorm layers to your models. Let me show you several ways to implement it.
Method 1: Use nn.BatchNorm2d for CNNs
If you’re working with convolutional neural networks (CNNs), you’ll typically use nn.BatchNorm2d:
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x

model = SimpleCNN()
x = torch.randn(1, 3, 32, 32)  # Dummy input
out = model(x)
print(out.shape)  # torch.Size([1, 16, 32, 32])

In this example, nn.BatchNorm2d(16) normalizes the 16 feature maps produced by the convolutional layer. The BatchNorm layer automatically tracks running statistics during training and uses them during evaluation.
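Those running statistics can be inspected directly: they are stored on the layer as buffers (not parameters), one value per channel. A small sketch with a standalone BatchNorm2d layer:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
print(bn.running_mean.shape)   # torch.Size([16]) - one value per channel
print(bn.num_batches_tracked)  # tensor(0) - no batches seen yet

# One forward pass in training mode updates the buffers
bn.train()
_ = bn(torch.randn(4, 16, 32, 32))
print(bn.num_batches_tracked)  # tensor(1)
```

Because these are buffers, they are saved in the model’s state_dict and restored on load, but they receive no gradients.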
Method 2: Use nn.BatchNorm1d for Fully Connected Layers
For fully connected (linear) layers, you’ll want to use nn.BatchNorm1d:
import torch
import torch.nn as nn

# Define a simple model with BatchNorm1d after a Linear layer
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc = nn.Linear(10, 5)
        self.bn = nn.BatchNorm1d(5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc(x)
        x = self.bn(x)  # Apply BatchNorm1d after the linear layer
        x = self.relu(x)
        return x

# Create input with batch size > 1
input_data = torch.randn(3, 10)  # batch_size = 3, features = 10

# Initialize and run the model
model = SimpleMLP()
output = model(input_data)

# Print output
print("Output:\n", output)

Using nn.BatchNorm1d after linear layers helps stabilize and accelerate training in fully connected neural networks by normalizing feature activations.
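One caveat worth knowing: in training mode, nn.BatchNorm1d cannot estimate a variance from a single sample, so a batch of size 1 raises an error. In evaluation mode the running statistics are used instead, so batch size 1 is fine. A quick sketch:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(5)
single = torch.randn(1, 5)  # batch_size = 1

bn.train()
try:
    bn(single)
except ValueError as e:
    print("Training mode failed:", e)

bn.eval()  # uses running statistics, so batch size 1 works
out = bn(single)
print(out.shape)  # torch.Size([1, 5])
```

This is why the example above creates its dummy input with a batch size of 3.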
Method 3: Add BatchNorm to Sequential Models
If you’re using nn.Sequential, adding BatchNorm is straightforward:
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(128 * 8 * 8, 512),  # assumes 32x32 input, pooled twice to 8x8
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 10)
)

Understand BatchNorm Parameters in PyTorch
When you create a BatchNorm layer in PyTorch, there are several parameters to consider:
torch.nn.BatchNorm2d(
    num_features,
    eps=1e-5,
    momentum=0.1,
    affine=True,
    track_running_stats=True
)

- num_features: Number of features or channels (required)
- eps: Small constant added to the variance for numerical stability
- momentum: Value used for the running mean/variance update
- affine: If True, adds learnable scale and shift parameters
- track_running_stats: If True, tracks running mean and variance during training
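The momentum parameter deserves a closer look, because it works differently from optimizer momentum: each training step updates the running estimate as running = (1 - momentum) * running + momentum * batch_stat. A sketch verifying that update for the running mean, assuming the default momentum=0.1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # momentum defaults to 0.1
x = torch.randn(8, 3, 4, 4)

bn.train()
_ = bn(x)

# running_mean starts at 0, so after one step it should be
# (1 - 0.1) * 0 + 0.1 * batch_mean
batch_mean = x.mean(dim=(0, 2, 3))
expected = 0.1 * batch_mean
print(torch.allclose(bn.running_mean, expected, atol=1e-6))  # True
```

A larger momentum makes the running statistics adapt faster but more noisily; the default of 0.1 is a reasonable starting point for most workloads.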
Batch Normalization During Training vs. Inference
A key aspect of BatchNorm is that it behaves differently during training and inference:
# Set to training mode - uses batch statistics
model.train()
output = model(input)

# Set to evaluation mode - uses running statistics
model.eval()
prediction = model(test_input)

During training (model.train()), BatchNorm calculates mean and variance from the current mini-batch. During inference (model.eval()), it uses the running statistics accumulated during training.
This difference is handled automatically when you switch modes with model.train() and model.eval().
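You can verify the difference directly: the same input produces different outputs depending on the mode, because training mode normalizes with the current batch’s statistics while evaluation mode uses the running estimates, which after only one batch are still far from the batch statistics. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5  # data far from mean 0 / variance 1

bn.train()
out_train = bn(x)  # normalized with this batch's mean/var

bn.eval()
out_eval = bn(x)   # normalized with the running statistics

# Training-mode output is standardized; eval-mode output is not, because
# the running stats have only partially caught up after one batch
print(out_train.mean().abs().item() < 0.1)  # True - close to 0
print(out_eval.mean().abs().item() > 1.0)   # True - still far from 0
```

Forgetting model.eval() at inference time is one of the most common BatchNorm bugs: predictions on small or single-sample batches become erratic because each batch is normalized by its own noisy statistics.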
Common Issues and Solutions with BatchNorm
Next, let’s look at common issues you may run into while working with Batch Normalization, along with their solutions.
Issue 1: Small Batch Sizes
When using very small batch sizes (e.g., 1 or 2), BatchNorm performance degrades significantly since statistics estimated from such small samples are unreliable.
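To see why, compare how noisy the batch mean is at different batch sizes. This is a hypothetical sampling experiment rather than BatchNorm itself, but it shows how unreliable the statistics become when estimated from only a couple of samples:

```python
import torch

torch.manual_seed(0)
data = torch.randn(10_000)  # "activations" with true mean 0, std 1

def batch_mean_spread(batch_size, n_batches=500):
    # Draw many mini-batches and measure how much the batch mean varies
    idx = torch.randint(0, data.numel(), (n_batches, batch_size))
    return data[idx].mean(dim=1).std()

print(f"batch=2:  {batch_mean_spread(2).item():.3f}")   # roughly 1/sqrt(2)
print(f"batch=64: {batch_mean_spread(64).item():.3f}")  # roughly 1/sqrt(64)
```

The spread of the batch mean shrinks like 1/sqrt(batch_size), so a batch of 2 gives estimates that fluctuate wildly from step to step, and BatchNorm normalizes with those fluctuating values.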
Solution: Use Group Normalization instead, which doesn’t depend on batch size:
# Replace BatchNorm with GroupNorm
# self.bn1 = nn.BatchNorm2d(16)
self.gn1 = nn.GroupNorm(4, 16)  # 4 groups, 16 channels

Issue 2: Recurrent Neural Networks
BatchNorm can be tricky with RNNs because sequence lengths may vary, and temporal dependencies are important.
Solution: Use Layer Normalization, which normalizes across features rather than across the batch:
class RNNWithLayerNorm(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(RNNWithLayerNorm, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size)
        self.ln = nn.LayerNorm(hidden_size)

    def forward(self, x):
        output, _ = self.rnn(x)
        output = self.ln(output)
        return output

Issue 3: BatchNorm with Pre-trained Models
When fine-tuning pre-trained models, you might need to handle BatchNorm carefully.
Solution: You can freeze BatchNorm layers if the target domain is similar to the source domain:
# Freeze BatchNorm layers
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # Set to evaluation mode
        module.weight.requires_grad = False
        module.bias.requires_grad = False

# Note: calling model.train() later puts these layers back in training
# mode, so re-apply module.eval() after each call to model.train()

Practical Example: CIFAR-10 Classification with BatchNorm
Let’s put everything together with a complete example using the CIFAR-10 dataset:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Data preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Model with BatchNorm
class CIFAR10Model(nn.Module):
    def __init__(self):
        super(CIFAR10Model, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Create model and move to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CIFAR10Model().to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0
print('Finished Training')

I’ve personally used this approach in numerous projects, and the performance improvement over models without BatchNorm is typically substantial. You’ll notice the model converges faster and often achieves better final accuracy.
BatchNorm has become a standard component in most deep neural networks for good reason; it works! When implemented correctly, it can turn a model that’s struggling to learn into one that trains quickly and generalizes well.
If you’re building deep learning models in PyTorch, I strongly recommend incorporating BatchNorm layers as a default practice. The performance benefits are almost always worth the minimal additional complexity.
I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, etc., for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and more. Check out my profile.