Supervised Contrastive Learning in Python Keras

When I first started training deep learning models for image classification, I often relied on the standard Cross-Entropy loss. It worked well enough, but I noticed the models struggled when the classes were visually similar.

I discovered that Supervised Contrastive Learning (SupCon) is a game-changer for these scenarios. It helps the model learn to pull similar images together and push different ones apart in the feature space.

In this tutorial, I will show you how to implement Supervised Contrastive Learning using Python Keras. We will move beyond basic examples and build something robust that handles complex data.

What is Supervised Contrastive Learning?

Supervised Contrastive Learning is a training technique that uses label information to group similar samples more effectively. Unlike self-supervised learning, it knows which images belong to the same category.

By using this approach in Python Keras, you can learn more generalizable features. This is particularly useful when you have a dataset where the differences between classes are very subtle.

Set Up Your Python Keras Environment

Before we dive into the logic, we need to ensure our environment is ready. I always make sure to have the latest version of TensorFlow and Keras installed to avoid compatibility issues.

You will need to import several modules to handle the data augmentation and the custom loss function. Here is the initial setup I use for most of my projects.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Verify the versions
print(f"TensorFlow version: {tf.__version__}")

Prepare the Dataset with Python Keras

For this example, imagine we are working with a dataset of retail products from a major US department store. We want to distinguish between different types of apparel accurately.

I prefer using the tf.data API because it is incredibly efficient for handling large batches. We will normalize the images and prepare them for the augmentation pipeline.

def prepare_dataset(images, labels, batch_size=128):
    # Normalize pixel values to [0, 1] range
    images = images.astype("float32") / 255.0
    
    # Create a tf.data.Dataset object
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

# Loading sample data (using CIFAR-10 as a placeholder for our retail data)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
train_ds = prepare_dataset(x_train, y_train)
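Before touching real data, I like to sanity-check the pipeline on a few synthetic images. The random arrays below are stand-ins for the retail dataset, and the snippet redefines prepare_dataset so it runs on its own:

```python
import numpy as np
import tensorflow as tf

def prepare_dataset(images, labels, batch_size=128):
    # Normalize pixel values to [0, 1] range
    images = images.astype("float32") / 255.0
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

# Synthetic stand-in data: 8 RGB images of size 32x32
fake_images = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
fake_labels = np.random.randint(0, 10, size=(8, 1))

ds = prepare_dataset(fake_images, fake_labels, batch_size=4)
images, labels = next(iter(ds))
print(images.shape, labels.shape)  # (4, 32, 32, 3) (4, 1)
```

If the shapes come out as expected and the pixel values sit in [0, 1], the real dataset can be dropped in without further changes.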

Implement Data Augmentation in Python Keras

Data augmentation is the heart of contrastive learning. We need to create different “views” of the same image so the model learns that a rotated or flipped shirt is still a shirt.

I usually wrap my augmentation layers into a sequential model. This makes it very easy to apply the same transformations across different training stages.

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ]
)

def apply_augmentation(images):
    # training=True is required: Keras preprocessing layers act as
    # identity in inference mode, e.g. when called outside of fit()
    return data_augmentation(images, training=True)
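That training=True flag is a common pitfall: without it, the random layers pass the batch through unchanged and both "views" end up identical. A quick check on a random stand-in batch shows the call preserves the batch shape while transforming the contents:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

batch = tf.random.uniform((4, 32, 32, 3))
# training=True makes the random layers actually transform the batch;
# two calls produce two different views of the same images
view_1 = data_augmentation(batch, training=True)
view_2 = data_augmentation(batch, training=True)
print(view_1.shape)  # (4, 32, 32, 3)
```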

Build the Encoder Network with Python Keras

The encoder is the part of the model that extracts features from the images. I often use a ResNet-based architecture because it provides a solid balance between speed and performance.

In this step, we strip the final classification layer. We only want the high-dimensional vector that represents the image features, often called the “embedding.”

def create_encoder():
    # ResNet50V2 backbone, trained from scratch (weights=None),
    # with the classification head removed
    base_model = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=(32, 32, 3)
    )
    
    # Add a global average pooling layer
    inputs = layers.Input(shape=(32, 32, 3))
    x = base_model(inputs)
    outputs = layers.GlobalAveragePooling2D()(x)
    
    return keras.Model(inputs, outputs, name="encoder")

encoder = create_encoder()

Add a Projection Head in Python Keras

The projection head is a small MLP that maps the encoder’s output to a lower-dimensional space. I’ve found that this significantly improves the quality of the learned representations.

After training is complete, we usually discard this head and use the encoder’s output for downstream tasks. It acts as a temporary “buffer” for the contrastive loss.

def add_projection_head(encoder):
    # Map the encoder output to a 128-dimensional space
    inputs = layers.Input(shape=(2048,))
    x = layers.Dense(512, activation="relu")(inputs)
    outputs = layers.Dense(128)(x)
    
    return keras.Model(inputs, outputs, name="projection_head")

projection_head = add_projection_head(encoder)
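A quick shape check confirms the two pieces fit together: after global average pooling, ResNet50V2 emits 2048-dimensional embeddings, which the head compresses to 128. The snippet rebuilds both models so it runs stand-alone:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def create_encoder():
    # ResNet50V2 backbone without pre-trained weights or the top head
    base_model = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=(32, 32, 3)
    )
    inputs = layers.Input(shape=(32, 32, 3))
    x = base_model(inputs)
    outputs = layers.GlobalAveragePooling2D()(x)
    return keras.Model(inputs, outputs, name="encoder")

encoder = create_encoder()

# Projection head: 2048 -> 512 -> 128
inputs = layers.Input(shape=(2048,))
x = layers.Dense(512, activation="relu")(inputs)
outputs = layers.Dense(128)(x)
projection_head = keras.Model(inputs, outputs, name="projection_head")

# Run a dummy batch through both stages to confirm the shapes line up
dummy = tf.random.uniform((2, 32, 32, 3))
emb = encoder(dummy)         # (2, 2048)
proj = projection_head(emb)  # (2, 128)
print(emb.shape, proj.shape)
```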

Define the Supervised Contrastive Loss in Python Keras

The loss function is what makes this “Supervised.” It calculates the similarity between all pairs in a batch and penalizes the model if samples with the same label are far apart.

I implement this using matrix multiplication for efficiency. We use a temperature parameter to control how sharply the loss penalizes negative pairs.

import tensorflow_addons as tfa  # pip install tensorflow-addons (project is now archived)

class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=0.05, name=None):
        super().__init__(name=name)
        self.temperature = temperature

    def call(self, labels, feature_vectors):
        # Normalize so the dot product below is cosine similarity
        feature_vectors = tf.math.l2_normalize(feature_vectors, axis=1)

        # Pairwise similarity logits, sharpened by the temperature
        logits = tf.divide(
            tf.matmul(feature_vectors, tf.transpose(feature_vectors)),
            self.temperature,
        )
        # npairs_loss treats samples with the same label as positives
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
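The npairs formulation only accounts for one positive pair per anchor. The full SupCon objective (Khosla et al., 2020) instead averages the log-probability over every same-label sample in the batch. Here is a self-contained sketch of that formula with no extra dependencies; the function name supcon_loss and the temperature default are my own choices, not from a library:

```python
import tensorflow as tf

def supcon_loss(labels, projections, temperature=0.1):
    """SupCon sketch: every same-label sample in the batch is a positive."""
    labels = tf.reshape(labels, (-1, 1))
    projections = tf.math.l2_normalize(projections, axis=1)

    # Pairwise cosine similarities scaled by the temperature
    logits = tf.matmul(projections, projections, transpose_b=True) / temperature

    # Exclude self-similarity on the diagonal
    mask_self = 1.0 - tf.eye(tf.shape(projections)[0])
    # Positives: same label, excluding self
    positives = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32) * mask_self

    # Log-softmax over all other samples in the batch
    log_prob = logits - tf.math.log(
        tf.reduce_sum(tf.exp(logits) * mask_self, axis=1, keepdims=True))

    # Average log-probability over each anchor's positives
    per_anchor = tf.reduce_sum(positives * log_prob, axis=1) / tf.maximum(
        tf.reduce_sum(positives, axis=1), 1.0)
    return -tf.reduce_mean(per_anchor)

# Tiny deterministic check: identical vectors with matching labels
# should score far lower than identical vectors with mismatched labels
feats = tf.constant([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
good = float(supcon_loss(tf.constant([0, 0, 1, 1]), feats))
bad = float(supcon_loss(tf.constant([0, 1, 0, 1]), feats))
print(good, bad)  # good is near zero, bad is large
```

Note the sketch skips the max-subtraction trick for numerical stability, which a production implementation should add.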

Train the SupCon Model with Python Keras

Now we combine everything into a custom training loop or use the Model.compile API. I prefer creating a custom train_step to have full control over the two-view augmentation.

This process trains the encoder to be extremely good at distinguishing between different product categories. It takes a bit longer than standard training but yields much better results.

class SupConModel(keras.Model):
    def __init__(self, encoder, projection_head):
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def train_step(self, data):
        images, labels = data
        
        # Create two augmented versions of the same batch
        aug_img_1 = apply_augmentation(images)
        aug_img_2 = apply_augmentation(images)
        
        with tf.GradientTape() as tape:
            # Get embeddings for both views
            p1 = self.projection_head(self.encoder(aug_img_1))
            p2 = self.projection_head(self.encoder(aug_img_2))
            
            # Both views share the original labels, so the label vector
            # is doubled to match the concatenated projections
            loss = self.compiled_loss(
                tf.concat([labels, labels], axis=0),
                tf.concat([p1, p2], axis=0),
            )

        # Update weights
        trainable_vars = self.encoder.trainable_variables + self.projection_head.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"loss": loss}
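To see the whole loop end to end, here is a downsized, runnable sketch. A toy encoder, projection head, and augmentation stand in for the ResNet pieces above so one epoch finishes in seconds, and a minimal SupCon-style loss is inlined to avoid extra dependencies. All layer sizes and the learning rate are illustrative, not tuned:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def supcon_loss(labels, projections, temperature=0.1):
    # Minimal SupCon-style loss, inlined so the snippet is self-contained
    labels = tf.reshape(labels, (-1, 1))
    projections = tf.math.l2_normalize(projections, axis=1)
    logits = tf.matmul(projections, projections, transpose_b=True) / temperature
    mask_self = 1.0 - tf.eye(tf.shape(projections)[0])
    positives = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32) * mask_self
    log_prob = logits - tf.math.log(
        tf.reduce_sum(tf.exp(logits) * mask_self, axis=1, keepdims=True))
    per_anchor = tf.reduce_sum(positives * log_prob, axis=1) / tf.maximum(
        tf.reduce_sum(positives, axis=1), 1.0)
    return -tf.reduce_mean(per_anchor)

# Toy stand-ins for the real encoder, projection head, and augmentation
encoder = keras.Sequential([keras.Input((32, 32, 3)),
                            layers.Conv2D(8, 3, activation="relu"),
                            layers.GlobalAveragePooling2D()])
projection_head = keras.Sequential([keras.Input((8,)), layers.Dense(4)])
augment = keras.Sequential([keras.Input((32, 32, 3)),
                            layers.RandomFlip("horizontal")])

class SupConModel(keras.Model):
    def __init__(self, encoder, projection_head):
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def train_step(self, data):
        images, labels = data
        # Two independently augmented views of the same batch
        aug_1 = augment(images, training=True)
        aug_2 = augment(images, training=True)
        with tf.GradientTape() as tape:
            p1 = self.projection_head(self.encoder(aug_1))
            p2 = self.projection_head(self.encoder(aug_2))
            # Labels are doubled to match the concatenated projections
            loss = supcon_loss(tf.concat([labels, labels], axis=0),
                               tf.concat([p1, p2], axis=0))
        variables = (self.encoder.trainable_variables
                     + self.projection_head.trainable_variables)
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        return {"loss": loss}

x = np.random.rand(16, 32, 32, 3).astype("float32")
y = np.random.randint(0, 3, size=(16,))
model = SupConModel(encoder, projection_head)
model.compile(optimizer=keras.optimizers.Adam(1e-3))
history = model.fit(x, y, batch_size=8, epochs=1, verbose=0)
print("loss:", history.history["loss"][-1])
```

Swapping the toy pieces for the real encoder, projection head, and SupervisedContrastiveLoss gives the full training setup.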

Fine-Tuning for Classification in Python Keras

Once the encoder is trained, we freeze it and add a linear classifier on top. This is where we see the real power of Supervised Contrastive Learning.

The features learned during the SupCon phase are much more robust. I often find that I need fewer epochs to reach high accuracy during this fine-tuning stage.

def build_classifier(encoder, num_classes=10):
    # Freeze the encoder weights
    encoder.trainable = False
    
    # Create the final classification model
    inputs = layers.Input(shape=(32, 32, 3))
    features = encoder(inputs)
    outputs = layers.Dense(num_classes, activation="softmax")(features)
    
    return keras.Model(inputs, outputs)

classifier = build_classifier(encoder)
classifier.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
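The key detail is that freezing the encoder leaves only the Dense head trainable, which is often called a "linear probe." The sketch below uses a tiny stand-in encoder (the real one is the SupCon-trained ResNet) so you can verify the trainable-parameter count directly:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Tiny stand-in encoder emitting 8 features per image
encoder = keras.Sequential([keras.Input((32, 32, 3)),
                            layers.Conv2D(8, 3, activation="relu"),
                            layers.GlobalAveragePooling2D()], name="encoder")

def build_classifier(encoder, num_classes=10):
    # Freeze the encoder so only the linear head is updated
    encoder.trainable = False
    inputs = layers.Input(shape=(32, 32, 3))
    features = encoder(inputs)
    outputs = layers.Dense(num_classes, activation="softmax")(features)
    return keras.Model(inputs, outputs)

classifier = build_classifier(encoder)
n_trainable = sum(int(tf.size(w)) for w in classifier.trainable_weights)
print(n_trainable)  # Dense head only: 8 * 10 weights + 10 biases = 90
```

With the frozen convolutional weights excluded, training converges quickly because only the final linear layer is being fit.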

Evaluate Model Performance with Python Keras

Finally, we evaluate our model on the test set. I always look at the confusion matrix to see if the model is still confusing similar-looking items.

In my experience, the clusters formed by SupCon are much tighter. This leads to a model that performs better on “out-of-distribution” data or slightly noisy images.

# Evaluate the classifier on the test dataset
loss, accuracy = classifier.evaluate(x_test.astype("float32") / 255.0, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
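The confusion matrix itself is easy to compute with tf.math.confusion_matrix. The predictions below are made-up values just to show the layout: rows are true labels, columns are predictions, so off-diagonal entries reveal which classes the model still mixes up:

```python
import numpy as np
import tensorflow as tf

# Hypothetical predictions vs. ground truth for a 3-class subset
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=3)
print(cm.numpy())
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
```

In practice you would feed in `y_test` and `classifier.predict(...).argmax(axis=1)` instead of these toy arrays.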


Conclusion

In this tutorial, I showed you how to move beyond standard classification and use Supervised Contrastive Learning in Python Keras. It is a powerful technique that can significantly boost your model’s ability to learn meaningful features.

I have found that while it requires more setup and computation time, the results in production environments are well worth the effort. You can use this approach for various image-based tasks where class separation is a challenge.
