Batch Normalization in TensorFlow

Recently, I was working on a deep learning project where my model was taking forever to train and its accuracy fluctuated wildly. The problem turned out to be that I wasn’t using batch normalization. Once I implemented it correctly, my training time dropped by about 40%, and the model became much more stable.

In this guide, I’ll cover everything you need to know about batch normalization in TensorFlow, from the basic concepts to implementation techniques that will improve your deep learning models.

So let’s get started!

Batch Normalization

Batch normalization is a technique that normalizes the inputs of each layer, making neural networks faster to train and more stable. It helps address the problem of internal covariate shift, where the distribution of inputs to layers changes during training.

Think of it as keeping your data in check throughout the entire network, not just at the beginning.

I’ve found that batch normalization often means the difference between a model that trains successfully and one that doesn’t converge at all.
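Conceptually, batch normalization standardizes each feature using the current batch’s mean and variance, then applies a learned scale (gamma) and shift (beta). Here’s a minimal NumPy sketch of the forward pass — a simplification for intuition, not TensorFlow’s actual implementation:

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-3):
    """Normalize each feature over the batch, then scale and shift."""
    mean = x.mean(axis=0)                  # per-feature mean over the batch
    var = x.var(axis=0)                    # per-feature variance over the batch
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

# Two features on wildly different scales
x = np.array([[1.0, 200.0],
              [3.0, 400.0]])
out = batch_norm_forward(x, gamma=np.ones(2), beta=np.zeros(2))
print(out)  # each column now has zero mean and (near) unit variance
```

In the real layer, gamma and beta are learned by backpropagation, and running averages of the mean and variance are accumulated for use at inference time.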

Method 1: Add Batch Normalization to Dense Layers

Let’s start with the simplest implementation – adding batch normalization to dense layers in TensorFlow:

import tensorflow as tf
from tensorflow.keras import layers

# Step 1: Define the model creation function
def create_model_with_batchnorm():
    model = tf.keras.Sequential([
        layers.Dense(256, input_shape=(784,)),   # Input layer expects 784 features (e.g. flattened 28x28 image)
        layers.BatchNormalization(),
        layers.Activation('relu'),

        layers.Dense(128),
        layers.BatchNormalization(),
        layers.Activation('relu'),

        layers.Dense(10, activation='softmax')   # Output layer for 10 classes (e.g. digits 0-9)
    ])

    return model

# Step 2: Create the model
model = create_model_with_batchnorm()

# Step 3: Compile the model - this sets up optimizer, loss, and metrics
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Step 4: Print the model summary to the console
model.summary()

# Step 5: (Optional) Train the model with dummy data to see training output

import numpy as np

# Create some fake data - 100 samples, 784 features each (like flattened 28x28 images)
X_train = np.random.rand(100, 784).astype('float32')

# Create fake labels (integers from 0 to 9) for 100 samples
y_train = np.random.randint(0, 10, size=(100,))

# Train the model for 5 epochs on this fake data
model.fit(X_train, y_train, epochs=5, batch_size=32)


Notice how I’ve separated the Dense layer from its activation function. This is because batch normalization is typically applied after the linear transformation but before the activation function.

I’ve found this pattern works best in most situations, as it normalizes the inputs to the activation functions, making them more effective.


Method 2: Batch Normalization with Convolutional Networks

For CNN architectures, batch normalization is equally effective:

import tensorflow as tf
from tensorflow.keras import layers, datasets
import numpy as np

# Define the model
def create_cnn_with_batchnorm():
    model = tf.keras.Sequential([
        layers.Conv2D(32, (3, 3), padding='same', input_shape=(28, 28, 1)),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.MaxPooling2D((2, 2)),

        layers.Conv2D(64, (3, 3), padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.MaxPooling2D((2, 2)),

        layers.Flatten(),
        layers.Dense(128),
        layers.BatchNormalization(),
        layers.Activation('relu'),

        layers.Dense(10, activation='softmax')
    ])
    return model

# Load MNIST data
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Normalize and reshape data
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = np.expand_dims(x_train, -1)  # (num_samples, 28, 28, 1)
x_test = np.expand_dims(x_test, -1)

# Create model
model = create_cnn_with_batchnorm()

# Compile model
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Show model architecture
model.summary()

# Train model
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.2)

# Evaluate model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"\nTest accuracy: {test_acc:.4f}")


When working with image data, I’ve noticed that batch normalization particularly shines with CNN architectures, often improving both accuracy and training speed.

Method 3: Configure Batch Normalization Parameters

TensorFlow’s BatchNormalization layer has several parameters you can adjust:

# Creating a batch normalization layer with custom parameters
bn_layer = layers.BatchNormalization(
    axis=-1,                # Normalize along the last axis
    momentum=0.99,          # Momentum for moving average
    epsilon=0.001,          # Small constant added for numerical stability
    center=True,            # If True, add offset of beta
    scale=True,             # If True, multiply by gamma
    beta_initializer='zeros',
    gamma_initializer='ones'
)

In my experience, the default parameters work well for most cases, but you might want to adjust them for specific scenarios:

  • Lower momentum values (like 0.9) make the moving average adapt more quickly
  • Higher epsilon values can help when dealing with very small numbers
  • Setting center=False or scale=False can slightly reduce the parameter count
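To get a feel for what momentum does, here’s a small sketch of the update rule Keras applies to the moving averages after each batch (`moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean`):

```python
def update_moving_mean(moving_mean, batch_mean, momentum):
    # Keras-style update: higher momentum = slower-adapting average
    return momentum * moving_mean + (1.0 - momentum) * batch_mean

mm_fast = mm_slow = 0.0
for _ in range(20):
    batch_mean = 5.0  # pretend every batch happens to have mean 5
    mm_fast = update_moving_mean(mm_fast, batch_mean, momentum=0.9)
    mm_slow = update_moving_mean(mm_slow, batch_mean, momentum=0.99)

print(round(mm_fast, 3), round(mm_slow, 3))  # ~4.392 vs ~0.910
```

After 20 identical batches, the momentum-0.9 average has nearly converged to the true mean, while the momentum-0.99 average is still far from it — which is why a lower momentum can help during short training runs.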


Method 4: Use Batch Normalization in Custom Training Loops

If you’re using custom training loops in TensorFlow, you need to be aware of the training flag:

def custom_training_step(model, optimizer, x_batch, y_batch):
    with tf.GradientTape() as tape:
        # Set training=True to use batch statistics during training
        predictions = model(x_batch, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

This is crucial – during training, batch normalization uses statistics from the current batch, but during inference, it uses the running statistics that were accumulated during training.

I’ve debugged countless models where performance dropped during inference because this flag wasn’t properly set.
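To see the flag’s effect directly, here’s a small sketch with a tiny throwaway model: the same batch produces different outputs depending on the flag, because batch statistics and moving statistics differ:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Tiny model just to illustrate the flag
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    layers.Dense(4),
    layers.BatchNormalization(),
    layers.Activation('relu'),
])

x = np.random.rand(8, 3).astype('float32')

train_out = model(x, training=True)    # normalizes with this batch's mean/variance
infer_out = model(x, training=False)   # normalizes with the moving averages

# The outputs differ because different statistics are used
print(np.allclose(train_out.numpy(), infer_out.numpy()))
```

If you see a large train/inference accuracy gap, checking this flag (and whether the moving averages were actually updated during training) is a good first debugging step.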


Method 5: Batch Normalization with Transfer Learning

When fine-tuning pre-trained models, you might want to freeze the batch normalization layers:

def create_transfer_learning_model():
    base_model = tf.keras.applications.ResNet50(include_top=False, 
                                               weights='imagenet',
                                               input_shape=(224, 224, 3))

    # Freeze batch normalization layers
    for layer in base_model.layers:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    model = tf.keras.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(1024),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.Dense(10, activation='softmax')
    ])

    return model

This approach has saved me lots of trouble when fine-tuning large pre-trained models. The pre-trained batch normalization layers already have good statistics for the base model, and updating them with small batches during fine-tuning can degrade performance.
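One TF2-specific detail worth knowing: setting trainable = False on a BatchNormalization layer not only freezes gamma and beta, it also makes the layer run in inference mode (using its moving statistics) even when the model is called with training=True. A quick check of the freezing effect:

```python
import tensorflow as tf
from tensorflow.keras import layers

bn = layers.BatchNormalization()
bn.build((None, 8))                    # creates gamma, beta, and moving stats

n_before = len(bn.trainable_weights)   # gamma and beta -> 2
bn.trainable = False
n_after = len(bn.trainable_weights)    # 0 once frozen

print(n_before, n_after)  # 2 0
```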


Real-World Example: MNIST Digit Classification

Let’s see batch normalization in action with a real example:

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# Load and preprocess data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

# Create models for comparison
def create_standard_model():
    return tf.keras.Sequential([
        layers.Dense(256, activation='relu', input_shape=(784,)),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax')
    ])

def create_batchnorm_model():
    return tf.keras.Sequential([
        layers.Dense(256, input_shape=(784,)),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.Dense(128),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.Dense(10, activation='softmax')
    ])

# Compile and train both models
standard_model = create_standard_model()
batchnorm_model = create_batchnorm_model()

standard_model.compile(optimizer='adam', 
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

# Use a higher learning rate for the batch normalized model
batchnorm_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])

print("Training standard model:")
standard_model.fit(x_train, y_train, epochs=5, batch_size=128, validation_split=0.1, verbose=2)

print("\nTraining batch normalized model:")
batchnorm_model.fit(x_train, y_train, epochs=5, batch_size=128, validation_split=0.1, verbose=2)

With the standard model, I typically see around 97% accuracy after 5 epochs. With batch normalization, I often reach 98-99% in the same time, even with a higher learning rate.


Common Pitfalls and How to Avoid Them

Through my years of experience, I’ve encountered several common mistakes when using batch normalization:

  1. Incorrect Order: Always place batch normalization after the linear transformation (Dense/Conv) but before the activation.
  2. Forgetting the Training Flag: Make sure to set training=True during training and training=False during inference.
  3. Small Batch Sizes: Batch normalization works best with reasonably sized batches (32+). Very small batches can lead to noisy statistics.
  4. Normalizing Activations: Don’t place batch normalization after the activation function (unless you have a specific reason).
  5. Using with RNNs: Batch normalization with RNNs requires special handling – consider layer normalization instead.

I’ve made all these mistakes at some point, and fixing them has consistently improved my models.


When NOT to Use Batch Normalization

While batch normalization is powerful, it’s not always the right choice:

  • For very small batch sizes (< 8), consider using Layer Normalization instead
  • In reinforcement learning with varying batch statistics
  • When you need deterministic outputs regardless of batch size
  • In some RNN architectures (although TensorFlow now has specialized versions for RNNs)
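For the small-batch case in particular, layers.LayerNormalization is a near drop-in replacement: it normalizes across the features of each individual sample, so it involves no batch statistics at all. A minimal sketch:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    layers.Dense(256),
    layers.LayerNormalization(),   # normalizes per sample, not per batch
    layers.Activation('relu'),
    layers.Dense(10, activation='softmax'),
])

# Works identically for a batch of 1 -- no batch statistics involved
single = np.random.rand(1, 784).astype('float32')
out = model(single, training=True)
print(out.shape)  # (1, 10)
```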

I hope you found this article helpful. Batch normalization is one of those techniques that dramatically improved my neural networks once I started using it correctly. It’s particularly useful for deeper models where training stability becomes an issue.

If you have any questions or suggestions, kindly leave them in the comments below.
