I’ve spent years working with Keras to build efficient deep learning models, and one challenge I often face is training image classifiers with limited labeled data. Semi-supervised learning, combined with contrastive pretraining, offers a powerful solution.
In this tutorial, I’ll walk you through how to use SimCLR, a popular contrastive learning framework, to pretrain your model and then fine-tune it for image classification using Keras.
This approach leverages unlabeled data effectively, making it ideal when labeled images are scarce but unlabeled images are abundant.
What is Semi-Supervised Image Classification with SimCLR in Keras?
Semi-supervised learning uses a small amount of labeled data alongside a larger pool of unlabeled data. SimCLR is a contrastive learning method that learns useful image representations without labels by maximizing agreement between differently augmented views of the same image.
Using Keras, we can pretrain a neural network with SimCLR on unlabeled images, then fine-tune it on labeled data for classification. This results in better performance compared to training from scratch.
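To make "maximizing agreement" concrete, here is a minimal NumPy sketch (purely illustrative, not part of SimCLR itself): the embedding of an augmented view stays close to its anchor in cosine similarity, while the embedding of an unrelated image does not. Contrastive training pushes the model toward exactly this geometry.

```python
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(42)
anchor = rng.normal(size=16)                        # embedding of an image
positive = anchor + rng.normal(scale=0.1, size=16)  # embedding of an augmented view
negative = rng.normal(size=16)                      # embedding of a different image

# SimCLR's training signal: pull the positive pair together, push negatives apart
assert cosine(anchor, positive) > cosine(anchor, negative)
```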
Method 1: Contrastive Pretraining with SimCLR in Keras
The core idea is to train a model to distinguish between augmented versions of the same image (positive pairs) and other images (negative pairs).
Step 1: Data Augmentation for Contrastive Learning
SimCLR relies heavily on strong data augmentations. Here’s a simple augmentation pipeline using TensorFlow:
import tensorflow as tf
def data_augment(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.1, upper=0.6)
    image = tf.image.random_saturation(image, lower=0.1, upper=0.6)
    image = tf.image.random_hue(image, max_delta=0.2)
    return image

This function creates diverse views of the same image, which is essential for contrastive learning.
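If you want to sanity-check the two-views idea without TensorFlow, a tiny NumPy stand-in (random horizontal flip plus a brightness shift, purely illustrative) behaves the same way: calling it twice on one image yields two distinct views of the same content.

```python
import numpy as np

rng = np.random.default_rng(7)

def toy_augment(image):
    # NumPy stand-in for the TF pipeline: random horizontal flip + brightness shift
    if rng.random() < 0.5:
        image = image[:, ::-1, :]
    return np.clip(image + rng.uniform(-0.5, 0.5), 0.0, 1.0)

image = rng.uniform(size=(4, 4, 3))
view1, view2 = toy_augment(image), toy_augment(image)

assert view1.shape == view2.shape == (4, 4, 3)
assert not np.allclose(view1, view2)  # two calls yield two distinct views
```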
Step 2: Creating the SimCLR Model in Keras
I use a ResNet50 backbone without the top classification layer as the encoder. Then, I add a projection head to map representations to a space where contrastive loss is applied.
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
def create_simclr_model(input_shape=(224, 224, 3), projection_dim=128):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = base_model(inputs)
    # Projection head
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dense(projection_dim)(x)
    model = models.Model(inputs, x)
    return model

This model outputs the embeddings used for the contrastive loss.
Step 3: Implementing Contrastive Loss
I use the NT-Xent loss, which encourages embeddings of augmented pairs to be close while pushing apart others.
import tensorflow as tf
def nt_xent_loss(batch_size, temperature=0.5):
    def loss(y_true, y_pred):
        y_pred = tf.math.l2_normalize(y_pred, axis=1)
        similarity_matrix = tf.matmul(y_pred, y_pred, transpose_b=True)
        # Rows i and i + batch_size are the two views of the same image,
        # so the positives sit on the +batch_size and -batch_size diagonals
        positives = tf.concat([
            tf.linalg.diag_part(similarity_matrix, k=batch_size),
            tf.linalg.diag_part(similarity_matrix, k=-batch_size),
        ], axis=0)
        numerator = tf.exp(positives / temperature)
        # Mask out self-similarity so it does not inflate the denominator
        mask = tf.eye(batch_size * 2)
        exp_sim = tf.exp(similarity_matrix / temperature) * (1 - mask)
        denominator = tf.reduce_sum(exp_sim, axis=1)
        loss = -tf.math.log(numerator / denominator)
        return tf.reduce_mean(loss)
    return loss

This function computes the loss over a batch of 2 * batch_size embeddings, treating each image's two augmented views as the positive pair and every other image in the batch as a negative.
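As a sanity check on the math, here is the same NT-Xent computation written in plain NumPy (an illustrative re-implementation, not the Keras code above). Embeddings whose positive pairs are perfectly aligned should score a lower loss than embeddings with no positive structure at all.

```python
import numpy as np

def nt_xent_numpy(z, batch_size, temperature=0.5):
    # z: (2 * batch_size, dim); rows i and i + batch_size are a positive pair
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = z @ z.T
    exp_sim = np.exp(sim / temperature)
    np.fill_diagonal(exp_sim, 0.0)  # drop self-similarity from the denominator
    positives = np.concatenate([np.diag(sim, k=batch_size), np.diag(sim, k=-batch_size)])
    numerator = np.exp(positives / temperature)
    denominator = exp_sim.sum(axis=1)
    return float(np.mean(-np.log(numerator / denominator)))

rng = np.random.default_rng(0)
views = rng.normal(size=(4, 8))
aligned = np.concatenate([views, views], axis=0)  # every positive pair identical
random_pairs = rng.normal(size=(8, 8))            # no real positive structure

assert nt_xent_numpy(aligned, 4) < nt_xent_numpy(random_pairs, 4)
```

The loss is always positive (the numerator is one term of the denominator's sum), and it shrinks as the positive pairs pull together relative to the negatives.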
Step 4: Preparing the Dataset for SimCLR
For each image, generate two augmented versions and create batches of size 2 * batch_size.
def prepare_simclr_dataset(dataset, batch_size):
    def augment_twice(images, labels):
        # Stack two augmented views along the batch axis so that
        # rows i and i + batch_size form a positive pair
        return (tf.concat([data_augment(images), data_augment(images)], axis=0),
                tf.concat([labels, labels], axis=0))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(augment_twice)
    return dataset

This yields batches of 2 * batch_size images in exactly the layout the NT-Xent loss expects.
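A quick NumPy illustration of that batch layout (a toy stand-in, not the tf.data code): two noisy "views" of each image are concatenated along the batch axis, so row i and row i + batch_size always come from the same source image.

```python
import numpy as np

rng = np.random.default_rng(1)
batch = rng.uniform(size=(3, 4, 4, 1))  # toy "batch" of 3 tiny images

# Two noisy "augmentations" of the same batch, stacked along the batch axis
view1 = np.clip(batch + rng.normal(scale=0.01, size=batch.shape), 0, 1)
view2 = np.clip(batch + rng.normal(scale=0.01, size=batch.shape), 0, 1)
pairs = np.concatenate([view1, view2], axis=0)

assert pairs.shape == (6, 4, 4, 1)  # 2 * batch_size examples per step
# Row i and row i + batch_size come from the same source image -- exactly
# the pairing the NT-Xent loss reads off the +/-batch_size diagonals
assert np.allclose(pairs[0], pairs[3], atol=0.1)
```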
Method 2: Fine-Tuning the Pretrained Encoder for Image Classification in Keras
After pretraining, I freeze the encoder weights and add a classification head.
def create_classifier(encoder, num_classes):
    encoder.trainable = False
    inputs = layers.Input(shape=(224, 224, 3))
    features = encoder(inputs)
    x = layers.Dense(256, activation='relu')(features)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model

This model can be trained on labeled data with categorical cross-entropy loss (or sparse categorical cross-entropy if the labels are integers).
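Conceptually, fine-tuning with a frozen encoder reduces to learning a small head on fixed feature vectors. Here is a NumPy sketch of that forward pass (the 2048 comes from ResNet50's pooled output; the weights are random placeholders, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(3)
features = rng.normal(size=(5, 2048))        # frozen-encoder outputs for 5 images
W = rng.normal(scale=0.01, size=(2048, 10))  # weights of the new softmax head
b = np.zeros(10)

logits = features @ W + b
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

assert probs.shape == (5, 10)
assert np.allclose(probs.sum(axis=1), 1.0)  # each row is a valid class distribution
```

Only W and b receive gradients during this stage, which is why fine-tuning converges quickly even with few labeled images.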
Full Example: Semi-Supervised Image Classification with SimCLR in Keras
Here’s the complete code combining both pretraining and fine-tuning:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
# Data augmentation
def data_augment(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.1, upper=0.6)
    image = tf.image.random_saturation(image, lower=0.1, upper=0.6)
    image = tf.image.random_hue(image, max_delta=0.2)
    return image
# SimCLR model
def create_simclr_model(input_shape=(224, 224, 3), projection_dim=128):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = base_model(inputs)
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dense(projection_dim)(x)
    model = models.Model(inputs, x)
    return model
# NT-Xent loss
def nt_xent_loss(batch_size, temperature=0.5):
    def loss(y_true, y_pred):
        y_pred = tf.math.l2_normalize(y_pred, axis=1)
        similarity_matrix = tf.matmul(y_pred, y_pred, transpose_b=True)
        # Positives sit on the +batch_size and -batch_size diagonals
        positives = tf.concat([
            tf.linalg.diag_part(similarity_matrix, k=batch_size),
            tf.linalg.diag_part(similarity_matrix, k=-batch_size),
        ], axis=0)
        numerator = tf.exp(positives / temperature)
        # Mask out self-similarity so it does not inflate the denominator
        mask = tf.eye(batch_size * 2)
        exp_sim = tf.exp(similarity_matrix / temperature) * (1 - mask)
        denominator = tf.reduce_sum(exp_sim, axis=1)
        loss = -tf.math.log(numerator / denominator)
        return tf.reduce_mean(loss)
    return loss
# Dataset preparation
def prepare_simclr_dataset(dataset, batch_size):
    def augment_twice(images, labels):
        # Concatenate two augmented views along the batch axis so that
        # rows i and i + batch_size form a positive pair
        return (tf.concat([data_augment(images), data_augment(images)], axis=0),
                tf.concat([labels, labels], axis=0))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(augment_twice)
    return dataset
# Fine-tune classifier
def create_classifier(encoder, num_classes):
    encoder.trainable = False
    inputs = layers.Input(shape=(224, 224, 3))
    features = encoder(inputs)
    x = layers.Dense(256, activation='relu')(features)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model
# Example usage
if __name__ == "__main__":
    batch_size = 32
    input_shape = (224, 224, 3)
    num_classes = 10  # Example number of classes

    # Load your unlabeled dataset here (e.g., tf.data.Dataset)
    # For demonstration, using random data
    unlabeled_images = tf.random.uniform((1000, 224, 224, 3))
    unlabeled_labels = tf.zeros((1000,))  # Dummy labels, unused by the contrastive loss
    unlabeled_dataset = tf.data.Dataset.from_tensor_slices((unlabeled_images, unlabeled_labels))
    simclr_dataset = prepare_simclr_dataset(unlabeled_dataset, batch_size)

    # Create SimCLR model
    simclr_model = create_simclr_model(input_shape=input_shape)

    # Compile model with contrastive loss
    simclr_model.compile(optimizer='adam', loss=nt_xent_loss(batch_size))

    # Pretrain with unlabeled data
    # Note: This example uses dummy data and may not run as-is for real training
    simclr_model.fit(simclr_dataset, epochs=10)

    # After pretraining, keep the representation before the projection head:
    # the input to the first Dense layer is the pooled ResNet50 feature vector
    encoder = models.Model(simclr_model.input, simclr_model.layers[-2].input)
    classifier = create_classifier(encoder, num_classes=num_classes)

    # Compile classifier (sparse loss because the labels are integers)
    classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Load labeled dataset here
    # For demonstration, using random data
    labeled_images = tf.random.uniform((200, 224, 224, 3))
    labeled_labels = tf.random.uniform((200,), maxval=num_classes, dtype=tf.int32)
    labeled_dataset = tf.data.Dataset.from_tensor_slices((labeled_images, labeled_labels)).batch(batch_size)

    # Fine-tune classifier on labeled data
    classifier.fit(labeled_dataset, epochs=5)

This method allows you to leverage large amounts of unlabeled images for pretraining, improving classification performance on limited labeled data.
I hope this tutorial helps you understand how to implement semi-supervised image classification with contrastive pretraining using SimCLR in Keras. If you have any questions or want to share your experience, feel free to leave a comment below. Happy coding!
Other Python Keras tutorials you may also like:
- Image Classification with ConvMixer in Keras
- Image Classification Using EANet in Python Keras
- Involutional Neural Networks in Python Using Keras
- Image Classification with Perceiver in Keras

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries, including Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere. Check out my profile.