I’ve spent years working with Keras to build efficient deep learning models, and one challenge I often face is training image classifiers with limited labeled data. Semi-supervised learning, combined with contrastive pretraining, offers a powerful solution.
In this tutorial, I’ll walk you through how to use SimCLR, a popular contrastive learning framework, to pretrain your model and then fine-tune it for image classification using Keras.
This approach leverages unlabeled data effectively, making it ideal when labeled images are scarce but unlabeled images are abundant.
What is Semi-Supervised Image Classification with SimCLR in Keras?
Semi-supervised learning uses a small amount of labeled data alongside a larger pool of unlabeled data. SimCLR is a contrastive learning method that learns useful image representations without labels by maximizing agreement between differently augmented views of the same image.
Using Keras, we can pretrain a neural network with SimCLR on unlabeled images, then fine-tune it on the labeled subset for classification. When labels are scarce, this usually performs better than training the same network from scratch on the labeled images alone.
Method 1: Contrastive Pretraining with SimCLR in Keras
The core idea is to train a model to distinguish between augmented versions of the same image (positive pairs) and other images (negative pairs).
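To make the pairing concrete, suppose a batch of N images produces 2N views, stacked as all first views followed by all matching second views (the layout the loss and dataset code in this tutorial assume). View i then has exactly one positive, view i + N (or i - N), and the remaining 2N - 2 views are its negatives. The small NumPy sketch below shows this bookkeeping; the helper names are hypothetical and not part of any library.

```python
import numpy as np

def positive_pair_indices(batch_size):
    """For 2 * batch_size views stacked as [views 1; views 2], return the
    index of each row's positive partner (row i pairs with row i + batch_size)."""
    first_half = np.arange(batch_size, 2 * batch_size)  # partners of rows 0..N-1
    second_half = np.arange(batch_size)                  # partners of rows N..2N-1
    return np.concatenate([first_half, second_half])

def negative_mask(batch_size):
    """Boolean (2N, 2N) mask that is True where column j is a negative for row i."""
    n = 2 * batch_size
    mask = ~np.eye(n, dtype=bool)                                  # drop self-pairs
    mask[np.arange(n), positive_pair_indices(batch_size)] = False  # drop positives
    return mask

print(positive_pair_indices(3))      # [3 4 5 0 1 2]
print(negative_mask(3).sum(axis=1))  # [4 4 4 4 4 4], i.e. 2N - 2 negatives per view
```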
Step 1: Data Augmentation for Contrastive Learning
SimCLR relies heavily on strong data augmentations. Here’s a simple augmentation pipeline using TensorFlow:
import tensorflow as tf

def data_augment(image):
    # Expects a float32 RGB image with values in [0, 1]
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.1, upper=0.6)
    image = tf.image.random_saturation(image, lower=0.1, upper=0.6)
    image = tf.image.random_hue(image, max_delta=0.2)
    # Brightness and contrast shifts can push values outside [0, 1]; clip them back
    return tf.clip_by_value(image, 0.0, 1.0)

This function creates diverse views of the same image, which is essential for contrastive learning.
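The SimCLR paper reports that combining random cropping (followed by resizing) with color distortion matters most, and the pipeline above has no cropping step. Below is a minimal sketch of a crop step you could apply before `data_augment`. It assumes square crops whose side is a random fraction of the shorter image side; the helper name and `min_scale` parameter are hypothetical, and the paper's recipe also randomizes the aspect ratio.

```python
import tensorflow as tf

def random_crop_and_resize(image, output_size=(224, 224), min_scale=0.3):
    # Hypothetical helper: square crop whose side is a random fraction
    # (min_scale to 1.0) of the image's shorter side, resized to output_size.
    shape = tf.shape(image)
    short_side = tf.cast(tf.minimum(shape[0], shape[1]), tf.float32)
    scale = tf.random.uniform([], minval=min_scale, maxval=1.0)
    crop_side = tf.maximum(tf.cast(scale * short_side, tf.int32), 1)
    crop = tf.image.random_crop(image, size=tf.stack([crop_side, crop_side, shape[2]]))
    return tf.image.resize(crop, output_size)

image = tf.random.uniform((256, 320, 3))
view = random_crop_and_resize(image)
print(view.shape)  # (224, 224, 3)
```

In the two-view pipeline, each view would then be produced as `data_augment(random_crop_and_resize(image))`, so the two crops of one image can cover different regions.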
Step 2: Creating the SimCLR Model in Keras
I use a ResNet50 backbone without the top classification layer as the encoder. Then, I add a projection head to map representations to a space where contrastive loss is applied.
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

def create_simclr_model(input_shape=(224, 224, 3), projection_dim=128):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = base_model(inputs)
    # Projection head
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dense(projection_dim)(x)
    model = models.Model(inputs, x)
    return model

This model outputs embeddings used for contrastive loss.
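For classification you will want the encoder output, not the projection: the SimCLR authors found that the representation taken before the projection head transfers better. One way to keep a handle on the backbone is a variant of `create_simclr_model` that returns both models. The function name below is hypothetical, and the small 64×64 input is only there to keep the demo fast.

```python
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

def create_simclr_encoder_and_model(input_shape=(224, 224, 3), projection_dim=128):
    """Hypothetical variant of create_simclr_model that also returns the encoder."""
    encoder = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = encoder(inputs)                       # 2048-d representation reused downstream
    x = layers.Dense(512, activation='relu')(features)
    projections = layers.Dense(projection_dim)(x)    # consumed only by the contrastive loss
    return encoder, models.Model(inputs, projections)

encoder, simclr_model = create_simclr_encoder_and_model(input_shape=(64, 64, 3))
batch = tf.random.uniform((4, 64, 64, 3))
print(encoder(batch).shape)       # (4, 2048)
print(simclr_model(batch).shape)  # (4, 128)
```

Both models share the same ResNet50 weights, so pretraining `simclr_model` also trains `encoder`.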
Step 3: Implementing Contrastive Loss
I use the NT-Xent loss, which encourages embeddings of augmented pairs to be close while pushing apart others.
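For reference, here is the NT-Xent loss written out. It assumes N images per batch (2N views) laid out so that view i's positive is view p(i) = i + N for i ≤ N and i − N otherwise, with z the projection-head outputs and τ the `temperature` argument:

```latex
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert},
\qquad
\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}
                        {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}

\mathcal{L} = \frac{1}{2N} \sum_{i=1}^{2N} \ell_{i,\,p(i)}
```

Each term is a softmax cross-entropy over the other 2N − 1 views in which the positive view is the correct class, which is why the loss can be computed with a standard cross-entropy once self-similarity is excluded.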
def nt_xent_loss(batch_size, temperature=0.5):
    def loss(y_true, y_pred):
        # y_pred: (2 * batch_size, dim); rows i and i + batch_size are two views of one image.
        # y_true is unused (contrastive pretraining needs no labels).
        z = tf.math.l2_normalize(y_pred, axis=1)
        logits = tf.matmul(z, z, transpose_b=True) / temperature
        # Exclude each embedding's similarity with itself by making it a huge negative logit
        logits = logits - tf.eye(2 * batch_size) * 1e9
        # Column index of each row's positive partner
        labels = tf.concat([tf.range(batch_size, 2 * batch_size), tf.range(batch_size)], axis=0)
        # Softmax cross-entropy over the other 2N - 1 views, with the positive as the target
        per_example = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
        return tf.reduce_mean(per_example)
    return loss

It expects every batch of embeddings to hold 2 * batch_size rows, where rows i and i + batch_size are the two views of the same image; all other rows in the batch act as negatives.
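A plain NumPy reference is handy for sanity-checking the TensorFlow loss on tiny inputs. The sketch below assumes the same [views 1; views 2] layout; the function name is hypothetical. In the toy example every positive pair is identical and every negative is orthogonal, so each row's loss works out to log(1 + 2e^(-2)) at τ = 0.5.

```python
import numpy as np

def nt_xent_reference(z, batch_size, temperature=0.5):
    """Plain NumPy NT-Xent for 2 * batch_size embeddings laid out as [views 1; views 2]."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # unit vectors, so dot product = cosine
    logits = z @ z.T / temperature
    np.fill_diagonal(logits, -np.inf)                  # exclude self-similarity
    n = 2 * batch_size
    positives = np.concatenate([np.arange(batch_size, n), np.arange(batch_size)])
    # Numerically stable log-sum-exp over each row (the softmax denominator)
    row_max = logits.max(axis=1, keepdims=True)
    log_denominator = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    log_numerator = logits[np.arange(n), positives]
    return float(np.mean(log_denominator - log_numerator))

# Two images; each positive pair is identical and the images are orthogonal
z_good = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
# Same vectors, but the second views are swapped, so "positives" are orthogonal
z_bad = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]])

good = nt_xent_reference(z_good, batch_size=2)  # equals np.log(1 + 2 * np.exp(-2)), about 0.2395
bad = nt_xent_reference(z_bad, batch_size=2)    # much larger: positives look like negatives
print(good, bad)
```

Feeding the same arrays (as `tf.float32` tensors) to the Keras loss function should give matching values, which is a quick way to catch indexing mistakes in the positive-pair layout.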
Step 4: Preparing the Dataset for SimCLR
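For each image, generate two augmented views. Each batch then stacks all first views followed by all matching second views, giving 2 * batch_size images per batch, which is the layout the loss above expects. Using drop_remainder=True keeps every batch at exactly that size, since the loss hard-codes batch_size.
def prepare_simclr_dataset(dataset, batch_size):
    # dataset: unbatched tf.data.Dataset of (image, label); labels are ignored
    def augment_twice(image, label):
        return data_augment(image), data_augment(image)
    def merge_views(view1, view2):
        images = tf.concat([view1, view2], axis=0)  # [first views; second views]
        # Dummy targets: Keras expects (x, y) pairs, but NT-Xent ignores y_true
        return images, tf.zeros((2 * batch_size,))
    dataset = dataset.map(augment_twice, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.map(merge_views).prefetch(tf.data.AUTOTUNE)

You can refer to the screenshot below to see the output.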

This prepares the data in the required format.
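To check the batch layout without building the full model, the sketch below mirrors `prepare_simclr_dataset` with a pluggable augmentation and tiny 8×8 random images. The helper name is hypothetical, and a horizontal flip stands in for `data_augment` so the two views of each image are easy to compare.

```python
import tensorflow as tf

def make_two_view_batches(images, batch_size, augment):
    """Hypothetical helper mirroring prepare_simclr_dataset: each batch is
    [augment(x_1..x_N); augment(x_1..x_N)] with dummy targets."""
    ds = tf.data.Dataset.from_tensor_slices(images)
    ds = ds.map(lambda x: (augment(x), augment(x)))
    ds = ds.batch(batch_size, drop_remainder=True)   # 5 images -> 2 full batches of 2
    return ds.map(lambda v1, v2: (tf.concat([v1, v2], axis=0),
                                  tf.zeros((2 * batch_size,))))

images = tf.random.uniform((5, 8, 8, 3))
batches = make_two_view_batches(images, batch_size=2, augment=tf.image.random_flip_left_right)
for x, y in batches:
    print(x.shape, y.shape)  # (4, 8, 8, 3) (4,)
```

Row i and row i + 2 of each batch come from the same source image, either identical or mirrored, which is exactly the pairing the loss relies on.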
Method 2: Fine-Tuning the Pretrained Encoder for Image Classification in Keras
After pretraining, I freeze the encoder weights and add a classification head.
def create_classifier(encoder, num_classes):
    encoder.trainable = False  # keep the pretrained weights fixed
    inputs = layers.Input(shape=(224, 224, 3))
    features = encoder(inputs)
    x = layers.Dense(256, activation='relu')(features)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model

This model can be trained on the labeled data with a cross-entropy loss: categorical_crossentropy for one-hot labels, or sparse_categorical_crossentropy for integer labels as in the full example below.
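With the encoder frozen, only the new head learns, which is closer to a linear-evaluation setup than to full fine-tuning. A common follow-up is to unfreeze the encoder and keep training at a much lower learning rate; Keras requires recompiling after changing `trainable`. The sketch below assumes a tiny stand-in encoder and random data so it runs quickly; in the tutorial you would pass the pretrained encoder instead. Following the Keras transfer-learning guide, the encoder is called with `training=False` so BatchNormalization layers (ResNet50 has many) stay in inference mode during fine-tuning.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Hypothetical small encoder standing in for the pretrained ResNet50 backbone
encoder = models.Sequential([
    layers.Input(shape=(32, 32, 3)),
    layers.Conv2D(16, 3, activation='relu'),
    layers.GlobalAveragePooling2D(),
])

inputs = layers.Input(shape=(32, 32, 3))
features = encoder(inputs, training=False)  # keep BatchNorm layers in inference mode
x = layers.Dense(64, activation='relu')(features)
outputs = layers.Dense(10, activation='softmax')(x)
classifier = models.Model(inputs, outputs)

# Random stand-in data for the labeled set
x_train = tf.random.uniform((64, 32, 32, 3))
y_train = tf.random.uniform((64,), maxval=10, dtype=tf.int32)

# Stage 1: frozen encoder, train only the classification head
encoder.trainable = False
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
classifier.fit(x_train, y_train, epochs=1, verbose=0)
frozen_count = len(classifier.trainable_weights)  # head kernels and biases only

# Stage 2: unfreeze the encoder and recompile with a small learning rate
encoder.trainable = True
classifier.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                   loss='sparse_categorical_crossentropy', metrics=['accuracy'])
classifier.fit(x_train, y_train, epochs=1, verbose=0)
print(frozen_count, len(classifier.trainable_weights))  # 4 6
```

The low learning rate in stage 2 is a design choice to avoid overwriting the pretrained features with large updates driven by a small labeled set.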
Full Example: Semi-Supervised Image Classification with SimCLR in Keras
Here’s the complete code combining both pretraining and fine-tuning:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

# Data augmentation (expects float32 RGB images in [0, 1])
def data_augment(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.1, upper=0.6)
    image = tf.image.random_saturation(image, lower=0.1, upper=0.6)
    image = tf.image.random_hue(image, max_delta=0.2)
    return tf.clip_by_value(image, 0.0, 1.0)

# SimCLR model: ResNet50 encoder + projection head
def create_simclr_model(input_shape=(224, 224, 3), projection_dim=128):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = base_model(inputs)
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dense(projection_dim)(x)
    model = models.Model(inputs, x)
    return model

# NT-Xent loss; rows i and i + batch_size of y_pred are two views of one image
def nt_xent_loss(batch_size, temperature=0.5):
    def loss(y_true, y_pred):
        z = tf.math.l2_normalize(y_pred, axis=1)
        logits = tf.matmul(z, z, transpose_b=True) / temperature
        # Exclude each embedding's similarity with itself
        logits = logits - tf.eye(2 * batch_size) * 1e9
        # Column index of each row's positive partner
        labels = tf.concat([tf.range(batch_size, 2 * batch_size), tf.range(batch_size)], axis=0)
        per_example = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
        return tf.reduce_mean(per_example)
    return loss

# Dataset preparation: each batch is [first views; second views] with dummy targets
def prepare_simclr_dataset(dataset, batch_size):
    def augment_twice(image, label):
        return data_augment(image), data_augment(image)
    def merge_views(view1, view2):
        return tf.concat([view1, view2], axis=0), tf.zeros((2 * batch_size,))
    dataset = dataset.map(augment_twice, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.map(merge_views).prefetch(tf.data.AUTOTUNE)

# Classifier on top of the (frozen) encoder
def create_classifier(encoder, num_classes):
    encoder.trainable = False
    inputs = layers.Input(shape=(224, 224, 3))
    features = encoder(inputs)
    x = layers.Dense(256, activation='relu')(features)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model

# Example usage
if __name__ == "__main__":
    batch_size = 32
    input_shape = (224, 224, 3)
    num_classes = 10  # Example number of classes

    # Load your unlabeled dataset here (an unbatched tf.data.Dataset of (image, label))
    # For demonstration, using random data; labels are ignored during pretraining
    unlabeled_images = tf.random.uniform((1000, 224, 224, 3))
    unlabeled_labels = tf.zeros((1000,))  # Dummy labels
    unlabeled_dataset = tf.data.Dataset.from_tensor_slices((unlabeled_images, unlabeled_labels))
    simclr_dataset = prepare_simclr_dataset(unlabeled_dataset, batch_size)

    # Create SimCLR model and compile it with the contrastive loss
    simclr_model = create_simclr_model(input_shape=input_shape)
    simclr_model.compile(optimizer='adam', loss=nt_xent_loss(batch_size))

    # Pretrain with unlabeled data
    # Random images only demonstrate the API; use real images and more epochs in practice
    simclr_model.fit(simclr_dataset, epochs=10)

    # After pretraining, reuse the ResNet50 backbone (without the projection head) as the encoder
    encoder = simclr_model.get_layer('resnet50')
    classifier = create_classifier(encoder, num_classes=num_classes)

    # Compile classifier
    classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Load labeled dataset here
    # For demonstration, using random data
    labeled_images = tf.random.uniform((200, 224, 224, 3))
    labeled_labels = tf.random.uniform((200,), maxval=num_classes, dtype=tf.int32)
    labeled_dataset = tf.data.Dataset.from_tensor_slices((labeled_images, labeled_labels)).batch(batch_size)

    # Train the classifier on labeled data
    classifier.fit(labeled_dataset, epochs=5)

This method lets you use large amounts of unlabeled images for pretraining, which can improve classification performance when labeled data is limited.
I hope this tutorial helps you understand how to implement semi-supervised image classification with contrastive pretraining using SimCLR in Keras. If you have any questions or want to share your experience, feel free to leave a comment below. Happy coding!
Bijay Kumar is an experienced Python and AI professional who enjoys helping developers learn modern technologies through practical tutorials and examples. His expertise includes Python development, Machine Learning, Artificial Intelligence, automation, and data analysis using libraries like Pandas, NumPy, TensorFlow, Matplotlib, SciPy, and Scikit-Learn. At PythonGuides.com, he shares in-depth guides designed for both beginners and experienced developers.