Data Parallel Training with KerasHub and tf.distribute

Scaling deep learning models used to be a daunting task that required complex configurations and deep infrastructure knowledge.

During my four years as a Keras developer, I have found that KerasHub combined with tf.distribute simplifies this process immensely.

In this tutorial, I will show you exactly how to implement data parallel training to speed up your model training across multiple GPUs.

Whether you are working with large language models or computer vision, these techniques will help you handle massive datasets efficiently.

Data Parallelism in Keras 3

Data parallelism is a technique where each global batch of training data is split into smaller per-device batches that are processed simultaneously on different devices.

Each device holds a full copy of the model, and after each step the gradients are aggregated across devices so that every copy applies the same update and the replicas stay identical.

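import os

# tf.distribute requires the TensorFlow backend, so select it before importing Keras
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_hub
import tensorflow as tf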

# Check available devices
print("Devices available:", tf.config.list_physical_devices())

I have used this approach frequently when my training data outgrows the processing power of a single high-end GPU.

It is the most common way to scale because it is straightforward to implement and often gives close to linear speedup, provided each GPU still receives a batch large enough to keep it busy.
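To make the gradient sync concrete, here is a minimal NumPy sketch. It is not how tf.distribute works internally. It only shows, for an assumed toy linear model with a mean squared error loss, that averaging the gradients from equal-sized shards gives the same result as computing the gradient on the full batch. The function name mse_gradient and the choice of four replicas are my own for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))   # one global batch of 32 examples
y = rng.normal(size=(32,))
w = rng.normal(size=(4,))      # model weights, identical on every replica


def mse_gradient(X, y, w):
    """Gradient of the mean squared error for a linear model y_hat = X @ w."""
    residual = X @ w - y
    return 2.0 * X.T @ residual / len(y)


num_replicas = 4
# Each simulated replica receives an equal slice of the global batch.
shards = zip(np.split(X, num_replicas), np.split(y, num_replicas))
per_replica_grads = [mse_gradient(Xs, ys, w) for Xs, ys in shards]

# Simulated all-reduce (mean): every replica ends up with the same gradient.
synced_grad = np.mean(per_replica_grads, axis=0)
full_batch_grad = mse_gradient(X, y, w)

print("Matches full-batch gradient:", np.allclose(synced_grad, full_batch_grad))
```

Because the shards are the same size, the mean of the per-shard gradients equals the full-batch gradient, which is why every replica can apply the same update.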

Set Up the MirroredStrategy for Local Multi-GPU Training

The tf.distribute.MirroredStrategy is my go-to choice when I have multiple GPUs connected to a single machine or workstation.

It creates a copy of all variables on each processor and uses all-reduce to keep everything in sync during the training loop.

# Initialize the mirrored strategy
strategy = tf.distribute.MirroredStrategy()

print(f"Number of devices in sync: {strategy.num_replicas_in_sync}")

With this strategy, your KerasHub models are replicated on every GPU that TensorFlow can see on the local machine.

I always check the num_replicas_in_sync property to verify that the system correctly identifies all the installed graphics cards.
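If the same script has to run on both single-GPU and multi-GPU machines, one option is to choose the strategy at startup. The sketch below assumes that falling back to TensorFlow's default (single-device) strategy is acceptable when fewer than two GPUs are visible. The helper name pick_strategy is hypothetical.

```python
import tensorflow as tf


def pick_strategy():
    """Return MirroredStrategy when several GPUs are visible, else the default strategy."""
    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        return tf.distribute.MirroredStrategy()
    # Outside any scope, get_strategy() returns the default single-device strategy.
    return tf.distribute.get_strategy()


strategy = pick_strategy()
print("GPUs visible:", len(tf.config.list_physical_devices("GPU")))
print("Replicas in sync:", strategy.num_replicas_in_sync)
```

The returned object exposes scope() either way, so the rest of the training code does not need to change.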

Load a KerasHub Model within a Distributed Scope

To make the model aware of the distribution strategy, you must initialize it inside the strategy.scope() block.

KerasHub makes this easy by allowing us to pull pre-trained backbones like BERT or Llama and wrap them in a distributed context.

# Define the model loading within the distribution scope
with strategy.scope():
    # Loading a pre-trained BERT backbone for a text classification task
    classifier = keras_hub.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased",
        num_classes=2
    )
    
    # Compile the model with distributed-aware optimizer
    classifier.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-5),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"]
    )

By placing the from_preset call inside the scope, Keras ensures that the model weights are mirrored across all your GPUs.

In my experience, forgetting to wrap the model initialization is the number one reason why developers see no performance gain.
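A quick way to catch this mistake is to ask TensorFlow whether a strategy is currently active. The sketch below uses tf.distribute.has_strategy() and tf.distribute.get_strategy(). The variable names are my own, and the check is only a sanity test, not a guarantee that every variable was created in the right place.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Outside the scope, only the default strategy is active.
outside = tf.distribute.has_strategy()

with strategy.scope():
    # Anything created here (model, optimizer, metrics) is built under the strategy.
    inside = tf.distribute.has_strategy()
    same_strategy = tf.distribute.get_strategy() is strategy

print("Inside scope:", inside, "| outside scope:", outside)
```

Adding an assertion on tf.distribute.has_strategy() right before a from_preset call is a cheap guard during development.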

Prepare Distributed Datasets for Keras Training

When training in parallel, it is best to feed the model a tf.data.Dataset object so that the strategy can distribute the data properly.

This allows the strategy to split each global batch, meaning different slices are sent to different GPUs at the same time.

import numpy as np

# Simulating a dataset of US consumer reviews
num_samples = 1000
features = ["Great service at the New York branch", "The delivery to Chicago was delayed"] * 500
labels = np.random.randint(0, 2, size=(num_samples,))

def prepare_dataset(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

# Ensure the batch size is a multiple of the number of GPUs
GLOBAL_BATCH_SIZE = 32
train_dataset = prepare_dataset(features, labels, GLOBAL_BATCH_SIZE)

I highly recommend using tf.data.AUTOTUNE to prevent your CPU from becoming a bottleneck during the data loading process.

Always ensure your global batch size is divisible by the number of GPUs to keep the workload balanced across the replicas.
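The divisibility rule is simple arithmetic, and a small helper can enforce it. This is a sketch with hypothetical function names. In a real script, the replica count would come from strategy.num_replicas_in_sync.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Split a global batch evenly across replicas, or fail loudly."""
    if num_replicas < 1:
        raise ValueError("num_replicas must be at least 1")
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f"Global batch size {global_batch_size} is not divisible "
            f"by {num_replicas} replicas"
        )
    return global_batch_size // num_replicas


def global_batch_size_for(per_replica: int, num_replicas: int) -> int:
    """Scale a per-GPU batch size up to the global batch size."""
    return per_replica * num_replicas


print(per_replica_batch_size(32, 4))   # 8
print(global_batch_size_for(16, 2))    # 32
```

Starting from the per-GPU batch size that fits in memory and multiplying by the replica count is often easier than picking a global number first.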

Execute the Distributed Training Loop in Keras

Once the model and dataset are ready, the actual training process looks exactly like a standard Keras fit() call.

The tf.distribute engine handles the background complexity of syncing gradients and updating weights across all devices.

# Start the distributed training process
history = classifier.fit(
    train_dataset,
    epochs=3,
    verbose=1
)

# Verify training completion
print("Training session finished successfully.")

With more than one GPU and a large enough dataset, each epoch should run noticeably faster than on a single device, without any extra code changes.

I often monitor the logs to ensure the loss is decreasing steadily, which is a good sanity check that distributed training is behaving as expected.
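Rather than judging speed by eye, you can record the time per epoch and compare runs with and without the strategy. The callback below is a minimal sketch. EpochTimer is a hypothetical name, and it only measures wall-clock time around each epoch.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import time

import keras


class EpochTimer(keras.callbacks.Callback):
    """Record the wall-clock duration of each epoch in seconds."""

    def __init__(self):
        super().__init__()
        self.epoch_times = []
        self._start = None

    def on_epoch_begin(self, epoch, logs=None):
        self._start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_times.append(time.perf_counter() - self._start)


timer = EpochTimer()
# Usage with the model from this tutorial:
# classifier.fit(train_dataset, epochs=3, callbacks=[timer])
# print(timer.epoch_times)
```

The first epoch usually includes graph tracing and warm-up, so later epochs give a fairer comparison.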

Implement MultiWorkerMirroredStrategy for Cloud Clusters

If your project requires more power than a single machine can provide, the MultiWorkerMirroredStrategy is the standard solution.

This method allows you to link multiple servers together, which is common when training on cloud platforms like AWS or Google Cloud.

# Example of setting up Multi-Worker training
# Note: This usually requires a TF_CONFIG environment variable
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
multi_worker_strategy = tf.distribute.MultiWorkerMirroredStrategy(cluster_resolver)

with multi_worker_strategy.scope():
    # Re-initialize or load the KerasHub model for the cluster
    multi_model = keras_hub.models.BertClassifier.from_preset(
        "bert_tiny_en_uncased",
        num_classes=2
    )
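    # The classifier outputs two logits per example, so use the same sparse
    # categorical loss as in the single-machine example
    multi_model.compile(
        optimizer="adam",
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )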

Setting up TF_CONFIG is essential here, as it tells each machine its role (chief or worker) and how to reach the other machines.

I’ve used this many times for production-level training where we need to process millions of rows of data within a few hours.
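For reference, TF_CONFIG is a JSON string with a cluster section listing every machine and a task section identifying the current one. The sketch below builds it for a two-worker cluster. The addresses and port are placeholders, and make_tf_config is a hypothetical helper. Each machine would run the same code with its own task index, and the variable has to be set before the strategy is created.

```python
import json
import os


def make_tf_config(worker_hosts, task_index):
    """Build the TF_CONFIG JSON string for one worker in the cluster."""
    return json.dumps(
        {
            "cluster": {"worker": list(worker_hosts)},
            "task": {"type": "worker", "index": task_index},
        }
    )


# Placeholder addresses: replace with the real host:port of each machine.
workers = ["10.0.0.1:12345", "10.0.0.2:12345"]

# On the first machine use task_index=0, on the second use task_index=1.
os.environ["TF_CONFIG"] = make_tf_config(workers, task_index=0)
print(os.environ["TF_CONFIG"])
```

When no separate chief job is listed, the worker with index 0 takes on the chief role.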

Save and Export Distributed KerasHub Models

After a long training session, you need to save your model so that it can be used later for inference on a single device.

Saving outside of the strategy scope is a best practice to ensure the saved file is a standard Keras model.

# Save the model after training is done
model_path = "us_sentiment_model.keras"
classifier.save(model_path)

# Reloading for standard inference
reloaded_model = keras.models.load_model(model_path)
print("Model saved and reloaded for inference.")

You can refer to the screenshot below to see the output.

When you save the model, Keras collapses the mirrored variables back into a single set of weights.

I find this extremely helpful because it means my deployment pipeline doesn’t need to know anything about the multi-GPU setup.
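You can check this round trip without downloading a preset. The sketch below uses a tiny stand-in model (a single Dense layer, my own choice for illustration) created inside a MirroredStrategy scope. It saves the model outside the scope, reloads it without any strategy, and compares the weights.

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tempfile

import keras
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Stand-in for the KerasHub classifier: variables are created under the strategy.
    tiny_model = keras.Sequential(
        [keras.Input(shape=(4,)), keras.layers.Dense(2)]
    )

# Save outside the scope, then reload as a regular single-device model.
path = os.path.join(tempfile.mkdtemp(), "tiny_model.keras")
tiny_model.save(path)
restored = keras.models.load_model(path)

weights_match = all(
    np.allclose(a, b)
    for a, b in zip(tiny_model.get_weights(), restored.get_weights())
)
print("Weights match after reload:", weights_match)
```

The same check can be run on the real classifier after training if you want extra confidence before deployment.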

In this tutorial, I showed you how to use KerasHub and tf.distribute to scale your training efficiently.

I have used these methods in several large-scale projects to cut down training time from days to hours.
