Deep Learning Stability with Gradient Centralization in Python Keras

In my years of working with deep learning, I have often hit a wall where my models just wouldn’t converge fast enough. I used to spend hours tweaking learning rates, only to find that the weight gradients were becoming unstable during backpropagation.

Then I discovered Gradient Centralization (GC), a simple yet powerful technique that operates directly on the gradients by shifting them to have a zero mean.

In this tutorial, I will show you exactly how I implement Gradient Centralization to get better performance out of my neural networks.

What is Gradient Centralization in Python Keras?

Gradient Centralization is an optimization technique that computes the mean of each column vector of the weight gradient matrix and subtracts it, so every centralized gradient has a zero mean.

I find this particularly useful because it acts like projected gradient descent on a constrained weight space, which regularizes the model and leads to faster, more stable training and better generalization on unseen data.
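
Before touching any optimizer code, here is a minimal sketch of the operation itself on a made-up 3x2 gradient matrix (the numbers are purely illustrative): after subtracting the column means, every column of the gradient averages to zero.

import tensorflow as tf

# A made-up 3x2 gradient matrix (3 input features, 2 output units), purely for illustration
grad = tf.constant([[0.9, -0.3],
                    [0.3,  0.6],
                    [0.6,  0.3]])

# Subtract the mean of each column vector (mean over axis 0)
centralized = grad - tf.reduce_mean(grad, axis=0, keepdims=True)

print(tf.reduce_mean(centralized, axis=0).numpy())  # both column means are now ~0.0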

Implement a Custom Optimizer with Gradient Centralization in Python Keras

The most effective way I have found to use GC is by wrapping a standard optimizer into a custom class.

This method allows you to intercept the gradients and centralize them before the optimizer updates the weights.

import tensorflow as tf
from tensorflow import keras
import numpy as np

# This class modifies the gradient step to include centralization
class GCOptimizer(keras.optimizers.Optimizer):
    def __init__(self, optimizer, name="GCOptimizer", **kwargs):
        super().__init__(name, **kwargs)
        self._optimizer = optimizer

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        new_grads_and_vars = []
        for grad, var in grads_and_vars:
            if grad is not None and len(grad.shape) > 1:
                # We subtract the mean from the gradient to centralize it
                grad -= tf.reduce_mean(grad, axis=list(range(len(grad.shape) - 1)), keepdims=True)
            new_grads_and_vars.append((grad, var))
        return self._optimizer.apply_gradients(new_grads_and_vars, name=name, **kwargs)

    def get_config(self):
        config = super().get_config()
        config.update({"optimizer": keras.optimizers.serialize(self._optimizer)})
        return config

# Example: Using it with Adam on a house price dataset
model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    keras.layers.Dense(1)
])
base_opt = keras.optimizers.Adam(learning_rate=0.01)
gc_opt = GCOptimizer(base_opt)

model.compile(optimizer=gc_opt, loss='mse')
print("Optimizer with Gradient Centralization is ready.")

Use the Wrapper Method for Existing Python Keras Optimizers

If you don’t want to build a full class, you can write a small helper function that centralizes the gradients before any existing Keras optimizer applies them.

I prefer this approach when I am experimenting with different optimizers like SGD or RMSprop on retail sales forecasting models.

def get_centralized_gradients(grads):
    # This helper centralizes every gradient tensor that has more than one dimension
    centralized_grads = []
    for grad in grads:
        if grad is not None and len(grad.shape) > 1:
            grad = grad - tf.reduce_mean(grad, axis=list(range(len(grad.shape) - 1)), keepdims=True)
        centralized_grads.append(grad)
    return centralized_grads

# Usage in a custom training loop with a plain SGD optimizer
optimizer = keras.optimizers.SGD(learning_rate=0.01)
x_batch = tf.random.normal((32, 10))  # dummy batch matching the 10-feature model above
y_batch = tf.random.normal((32, 1))

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x_batch, training=True) - y_batch))

grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(get_centralized_gradients(grads), model.trainable_variables))
print("Manual Gradient Centralization applied in a custom training loop.")

Apply Gradient Centralization in Python Keras via Model Subclassing

I often use model subclassing when I need more control over the train_step, especially for complex computer vision tasks.

By overriding the train_step, you can apply centralization directly within the model logic without touching the optimizer code.

class GCModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)

        gradients = tape.gradient(loss, self.trainable_variables)
        
        # Centralize the gradients before applying them
        centralized_gradients = []
        for grad in gradients:
            if grad is not None and len(grad.shape) > 1:
                grad = grad - tf.reduce_mean(grad, axis=list(range(len(grad.shape) - 1)), keepdims=True)
            centralized_gradients.append(grad)

        self.optimizer.apply_gradients(zip(centralized_gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

# US High-Speed Rail Traffic Prediction Example
inputs = keras.Input(shape=(12,))
outputs = keras.layers.Dense(1)(inputs)
model = GCModel(inputs, outputs)
model.compile(optimizer='adam', loss='mae')
print("Subclassed Model with GC integration complete.")

Implement Gradient Centralization in Python Keras for Convolutional Layers

When I work with image data, like analyzing California wildfire satellite imagery, I apply GC specifically to the 4D tensors of Conv2D layers.

In this method, the mean is calculated across the height, width, and input-channel dimensions of each filter, which normalizes the kernel gradients.

def centralize_conv_grads(grads):
    # Specialized centralization for 4D Convolutional kernels
    new_grads = []
    for g in grads:
        if len(g.shape) == 4: # Filtering for Conv layers
            g -= tf.reduce_mean(g, axis=[0, 1, 2], keepdims=True)
        new_grads.append(g)
    return new_grads

# Example of simulating a gradient update for an image classifier
sample_grads = [tf.random.normal([3, 3, 3, 64])]
processed_grads = centralize_conv_grads(sample_grads)
print(f"Original mean: {tf.reduce_mean(sample_grads[0])}")
print(f"Centralized mean: {tf.reduce_mean(processed_grads[0])}")

A Complete Example: Predict NYC Taxi Fares with Python Keras and GC

To show you how this works in a real scenario, I’ve put together a full script that predicts taxi fares using a dataset similar to the NYC Open Data.

This code includes the custom optimizer and a full training pipeline to demonstrate the performance boost.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# 1. Define the Gradient Centralization Logic
class GCOptimizer(keras.optimizers.Adam):
    def apply_gradients(self, grads_and_vars, *args, **kwargs):
        # Centralize every multi-dimensional gradient before Adam applies the update
        new_grads_and_vars = []
        for grad, var in grads_and_vars:
            if grad is not None and len(grad.shape) > 1:
                grad = grad - tf.reduce_mean(grad, axis=list(range(len(grad.shape) - 1)), keepdims=True)
            new_grads_and_vars.append((grad, var))
        return super().apply_gradients(new_grads_and_vars, *args, **kwargs)

# 2. Create a synthetic dataset for NYC Taxi Fares (Distance, Time, Passengers, etc.)
num_samples = 2000
train_x = np.random.rand(num_samples, 8).astype(np.float32)
train_y = np.random.rand(num_samples, 1).astype(np.float32) * 50 # Fare in Dollars

# 3. Build a Python Keras Regression Model
model = keras.Sequential([
    layers.Dense(128, activation='relu', input_shape=(8,)),
    layers.Dropout(0.2),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
])

# 4. Compile with the GC Optimizer
optimizer = GCOptimizer(learning_rate=0.001)
model.compile(optimizer=optimizer, loss=keras.losses.Huber(), metrics=['mae'])

# 5. Train the model
print("Starting training with Gradient Centralization...")
history = model.fit(train_x, train_y, epochs=10, batch_size=32, validation_split=0.2, verbose=1)

print("Training finished successfully.")

I have found that using Gradient Centralization makes my models much more robust against noisy data and high learning rates.

It is a simple addition to your workflow, but the stability it brings to the training process is often the difference between a failing model and a state-of-the-art one.

Give this a try on your next Python Keras project, and you will likely see a faster drop in your loss curves almost immediately.
