Semi-Supervised Image Classification with Contrastive Pretraining Using SimCLR in Keras

I’ve spent years working with Keras to build efficient deep learning models, and one challenge I often face is training image classifiers with limited labeled data. Semi-supervised learning, combined with contrastive pretraining, offers a powerful solution.

In this tutorial, I’ll walk you through how to use SimCLR, a popular contrastive learning framework, to pretrain your model and then fine-tune it for image classification using Keras.

This approach leverages unlabeled data effectively, making it ideal when labeled images are scarce but unlabeled images are abundant.

What is Semi-Supervised Image Classification with SimCLR in Keras?

Semi-supervised learning uses a small amount of labeled data alongside a larger pool of unlabeled data. SimCLR is a contrastive learning method that learns useful image representations without labels by maximizing agreement between differently augmented views of the same image.

Using Keras, we can pretrain a neural network with SimCLR on unlabeled images, then fine-tune it on labeled data for classification. This results in better performance compared to training from scratch.

Method 1: Contrastive Pretraining with SimCLR in Keras

The core idea is to train a model to distinguish between augmented versions of the same image (positive pairs) and other images (negative pairs).

Step 1: Data Augmentation for Contrastive Learning

SimCLR relies heavily on strong data augmentations. Here’s a simple augmentation pipeline using TensorFlow:

import tensorflow as tf

def data_augment(image):
    image = tf.image.random_flip_left_right(image)
    # Color jitter: factors roughly symmetric around "no change"
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
    image = tf.image.random_hue(image, max_delta=0.2)
    # Brightness/contrast can push values outside [0, 1]; clip back
    return tf.clip_by_value(image, 0.0, 1.0)

This function creates diverse color-jittered views of the same image, which is essential for contrastive learning. (The full SimCLR recipe also uses random cropping and Gaussian blur; color jitter alone keeps the demo simple.)
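
As a quick check (a hypothetical snippet with a random tensor standing in for a real image), calling the function twice on the same image yields two different views; in SimCLR these form a positive pair:

image = tf.random.uniform((224, 224, 3))  # stand-in for a real image
view_1 = data_augment(image)
view_2 = data_augment(image)  # a different view; (view_1, view_2) is a positive pair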

Step 2: Creating the SimCLR Model in Keras

I use a ResNet50 backbone without the top classification layer as the encoder. Then, I add a projection head to map representations to a space where contrastive loss is applied.

from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

def create_simclr_model(input_shape=(224, 224, 3), projection_dim=128):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = base_model(inputs)
    # Projection head
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dense(projection_dim)(x)
    model = models.Model(inputs, x)
    return model

This model outputs embeddings used for contrastive loss.
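
As a quick sanity check (dummy data, default arguments assumed), each image maps to a projection_dim-dimensional embedding:

model = create_simclr_model()
dummy_batch = tf.random.uniform((4, 224, 224, 3))
print(model(dummy_batch).shape)  # (4, 128)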

Step 3: Implementing Contrastive Loss

I use the NT-Xent loss, which encourages embeddings of augmented pairs to be close while pushing apart others.
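
For one positive pair (i, j) with L2-normalized embeddings z and temperature t, the standard NT-Xent loss from the SimCLR paper is:

loss(i, j) = -log( exp(sim(z_i, z_j) / t) / sum over k != i of exp(sim(z_i, z_k) / t) )

where sim is cosine similarity. The total loss averages this over all 2N views in the batch; the implementation below vectorizes that computation.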

def nt_xent_loss(batch_size, temperature=0.5):
    def loss(y_true, y_pred):
        # y_pred holds 2 * batch_size embeddings; rows i and i + batch_size
        # are the two augmented views of the same image
        y_pred = tf.math.l2_normalize(y_pred, axis=1)
        similarity_matrix = tf.matmul(y_pred, y_pred, transpose_b=True)
        # Mask out self-similarity with a large negative value so that
        # exp() sends it to ~0 in the denominator
        mask = tf.eye(batch_size * 2)
        similarity_matrix = similarity_matrix - mask * 1e9
        # Positive pairs sit on the +/- batch_size off-diagonals
        positives = tf.concat([
            tf.linalg.diag_part(similarity_matrix, k=batch_size),
            tf.linalg.diag_part(similarity_matrix, k=-batch_size),
        ], axis=0)
        numerator = tf.exp(positives / temperature)
        denominator = tf.reduce_sum(tf.exp(similarity_matrix / temperature), axis=1)
        return tf.reduce_mean(-tf.math.log(numerator / denominator))
    return loss

This function computes the loss over a batch of positive and negative pairs.
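
You can smoke-test the loss on random embeddings; y_true is unused inside the closure, so None works here (a hypothetical check, not part of training):

loss_fn = nt_xent_loss(batch_size=4)
dummy_embeddings = tf.random.uniform((8, 128))  # 2 * batch_size rows
print(loss_fn(None, dummy_embeddings))  # a positive scalar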

Step 4: Preparing the Dataset for SimCLR

For each batch of images, generate two augmented views and concatenate them along the batch axis, producing batches of 2 * batch_size images in the order the loss expects.

def prepare_simclr_dataset(dataset, batch_size):
    def augment_twice(images, labels):
        # Two augmented views of the batch, concatenated along the batch
        # axis: rows i and i + batch_size are views of the same image
        views = tf.concat([data_augment(images), data_augment(images)], axis=0)
        return views, tf.concat([labels, labels], axis=0)
    # drop_remainder keeps the batch size fixed, as the loss requires
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.map(augment_twice)

This prepares the data in the required format.
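
A minimal usage sketch with dummy tensors (assumed shapes, not real data):

images = tf.random.uniform((100, 224, 224, 3))
labels = tf.zeros((100,))
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
simclr_dataset = prepare_simclr_dataset(dataset, batch_size=32)  # batches of 64 views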

Method 2: Fine-Tuning the Pretrained Encoder for Image Classification in Keras

After pretraining, I freeze the encoder weights and add a classification head.

def create_classifier(encoder, num_classes):
    encoder.trainable = False
    inputs = layers.Input(shape=(224, 224, 3))
    features = encoder(inputs)
    x = layers.Dense(256, activation='relu')(features)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model

This model can be trained on labeled data; with integer labels, use sparse categorical cross-entropy, as the full example below does.
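
Once the new head has converged, a standard follow-up (not shown in the full example below) is to unfreeze the encoder and continue training end-to-end at a much lower learning rate:

# Optional second stage: fine-tune the whole network gently
encoder.trainable = True
classifier.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # low LR protects pretrained features
                   loss='sparse_categorical_crossentropy',
                   metrics=['accuracy'])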

Full Example: Semi-Supervised Image Classification with SimCLR in Keras

Here’s the complete code combining both pretraining and fine-tuning:

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

# Data augmentation
def data_augment(image):
    image = tf.image.random_flip_left_right(image)
    # Color jitter: factors roughly symmetric around "no change"
    image = tf.image.random_brightness(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
    image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
    image = tf.image.random_hue(image, max_delta=0.2)
    # Brightness/contrast can push values outside [0, 1]; clip back
    return tf.clip_by_value(image, 0.0, 1.0)

# SimCLR model
def create_simclr_model(input_shape=(224, 224, 3), projection_dim=128):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape, pooling='avg')
    inputs = layers.Input(shape=input_shape)
    features = base_model(inputs)
    x = layers.Dense(512, activation='relu')(features)
    x = layers.Dense(projection_dim)(x)
    model = models.Model(inputs, x)
    return model

# NT-Xent loss
def nt_xent_loss(batch_size, temperature=0.5):
    def loss(y_true, y_pred):
        # y_pred holds 2 * batch_size embeddings; rows i and i + batch_size
        # are the two augmented views of the same image
        y_pred = tf.math.l2_normalize(y_pred, axis=1)
        similarity_matrix = tf.matmul(y_pred, y_pred, transpose_b=True)
        # Mask out self-similarity with a large negative value so that
        # exp() sends it to ~0 in the denominator
        mask = tf.eye(batch_size * 2)
        similarity_matrix = similarity_matrix - mask * 1e9
        # Positive pairs sit on the +/- batch_size off-diagonals
        positives = tf.concat([
            tf.linalg.diag_part(similarity_matrix, k=batch_size),
            tf.linalg.diag_part(similarity_matrix, k=-batch_size),
        ], axis=0)
        numerator = tf.exp(positives / temperature)
        denominator = tf.reduce_sum(tf.exp(similarity_matrix / temperature), axis=1)
        return tf.reduce_mean(-tf.math.log(numerator / denominator))
    return loss

# Dataset preparation
def prepare_simclr_dataset(dataset, batch_size):
    def augment_twice(images, labels):
        # Two augmented views of the batch, concatenated along the batch
        # axis: rows i and i + batch_size are views of the same image
        views = tf.concat([data_augment(images), data_augment(images)], axis=0)
        return views, tf.concat([labels, labels], axis=0)
    # drop_remainder keeps the batch size fixed, as the loss requires
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.map(augment_twice)

# Fine-tune classifier
def create_classifier(encoder, num_classes):
    encoder.trainable = False
    inputs = layers.Input(shape=(224, 224, 3))
    features = encoder(inputs)
    x = layers.Dense(256, activation='relu')(features)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs, outputs)
    return model

# Example usage
if __name__ == "__main__":
    batch_size = 32
    input_shape = (224, 224, 3)
    num_classes = 10  # Example number of classes

    # Load your unlabeled dataset here (e.g., tf.data.Dataset)
    # For demonstration, using random data
    unlabeled_images = tf.random.uniform((1000, 224, 224, 3))
    unlabeled_labels = tf.zeros((1000,))  # Dummy labels (ignored by the contrastive loss)
    unlabeled_dataset = tf.data.Dataset.from_tensor_slices((unlabeled_images, unlabeled_labels))
    unlabeled_dataset = prepare_simclr_dataset(unlabeled_dataset, batch_size)

    # Create SimCLR model
    simclr_model = create_simclr_model(input_shape=input_shape)

    # Compile model with contrastive loss
    simclr_model.compile(optimizer='adam', loss=nt_xent_loss(batch_size))

    # Pretrain on the unlabeled data (dummy tensors here, so this run only
    # exercises the pipeline rather than learning useful features)
    simclr_model.fit(unlabeled_dataset, epochs=10)

    # After pretraining, keep only the ResNet50 backbone as the encoder;
    # SimCLR discards the projection head at this stage. layers[1] is the
    # base_model built inside create_simclr_model
    encoder = simclr_model.layers[1]
    classifier = create_classifier(encoder, num_classes=num_classes)

    # Compile classifier
    classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Load labeled dataset here
    # For demonstration, using random data
    labeled_images = tf.random.uniform((200, 224, 224, 3))
    labeled_labels = tf.random.uniform((200,), maxval=num_classes, dtype=tf.int32)
    labeled_dataset = tf.data.Dataset.from_tensor_slices((labeled_images, labeled_labels)).batch(batch_size)

    # Fine-tune classifier on labeled data
    classifier.fit(labeled_dataset, epochs=5)

This method allows you to leverage large amounts of unlabeled images for pretraining, improving classification performance on limited labeled data.
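
If you want to reuse the pretrained encoder in another project, you can save it on its own (a hypothetical filename; .keras is Keras's native format):

encoder.save('simclr_encoder.keras')
restored_encoder = tf.keras.models.load_model('simclr_encoder.keras')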

I hope this tutorial helps you understand how to implement semi-supervised image classification with contrastive pretraining using SimCLR in Keras. If you have any questions or want to share your experience, feel free to leave a comment below. Happy coding!
