When I first started training deep learning models for image classification, I often relied on the standard Cross-Entropy loss. It worked well enough, but I noticed the models struggled when the classes were visually similar.
I discovered that Supervised Contrastive Learning (SupCon) is a game-changer for these scenarios. It helps the model learn to pull similar images together and push different ones apart in the feature space.
In this tutorial, I will show you how to implement Supervised Contrastive Learning using Python Keras. We will move beyond basic examples and build something robust that handles complex data.
What is Supervised Contrastive Learning?
Supervised Contrastive Learning is a training technique that uses label information to group similar samples more effectively. Unlike self-supervised learning, it knows which images belong to the same category.
By using this approach in Python Keras, you can create more generalized features. This is particularly useful when you have a dataset where the differences between classes are very subtle.
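To make the "pull together, push apart" idea concrete, here is a toy NumPy sketch with made-up two-dimensional embeddings (illustrative only; real embeddings are high-dimensional vectors produced by the encoder). After a successful SupCon run, samples of the same class should have high cosine similarity and samples of different classes should not.

```python
import numpy as np

def cosine_similarity(a, b):
    # Cosine similarity: 1.0 means identical direction, ~0.0 means orthogonal
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical embeddings after training: two shirts point in roughly
# the same direction, while a shoe points somewhere else entirely
shirt_a = np.array([0.9, 0.1])
shirt_b = np.array([0.8, 0.2])
shoe = np.array([0.1, 0.9])

print(cosine_similarity(shirt_a, shirt_b))  # high: same class pulled together
print(cosine_similarity(shirt_a, shoe))     # low: different classes pushed apart
```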
Set Up Your Python Keras Environment
Before we dive into the logic, we need to ensure our environment is ready. I always make sure to have the latest version of TensorFlow and Keras installed to avoid compatibility issues.
You will need to import several modules to handle the data augmentation and the custom loss function. Here is the initial setup I use for most of my projects.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import tensorflow_addons as tfa  # provides npairs_loss, used by the contrastive loss below

# Verify the versions
print(f"TensorFlow version: {tf.__version__}")

Prepare the Dataset with Python Keras
For this example, imagine we are working with a dataset of retail products from a major US department store. We want to distinguish between different types of apparel accurately.
I prefer using the tf.data API because it is incredibly efficient for handling large batches. We will normalize the images and prepare them for the augmentation pipeline.
def prepare_dataset(images, labels, batch_size=128):
    # Normalize pixel values to [0, 1] range
    images = images.astype("float32") / 255.0
    # Create a tf.data.Dataset object
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    ds = ds.shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds

# Loading sample data (using CIFAR-10 as a placeholder for our retail data)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
train_ds = prepare_dataset(x_train, y_train)

Implement Data Augmentation in Python Keras
Data augmentation is the heart of contrastive learning. We need to create different “views” of the same image so the model learns that a rotated or flipped shirt is still a shirt.
I usually wrap my augmentation layers into a sequential model. This makes it very easy to apply the same transformations across different training stages.
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ]
)

def apply_augmentation(images):
    # training=True makes the random transformations explicit, even when
    # this is called from inside a custom train_step
    return data_augmentation(images, training=True)

Build the Encoder Network with Python Keras
The encoder is the part of the model that extracts features from the images. I often use a ResNet-based architecture because it provides a solid balance between speed and performance.
In this step, we strip the final classification layer. We only want the high-dimensional vector that represents the image features, often called the “embedding.”
def create_encoder():
    # ResNet50V2 backbone without the classification head.
    # weights=None trains from scratch; ImageNet weights expect larger inputs
    base_model = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=(32, 32, 3)
    )
    # Pool the spatial feature map into a single 2048-dimensional vector
    inputs = layers.Input(shape=(32, 32, 3))
    x = base_model(inputs)
    outputs = layers.GlobalAveragePooling2D()(x)
    return keras.Model(inputs, outputs, name="encoder")

encoder = create_encoder()

Add a Projection Head in Python Keras
The projection head is a small MLP that maps the encoder’s output to a lower-dimensional space. I’ve found that this significantly improves the quality of the learned representations.
After training is complete, we usually discard this head and use the encoder’s output for downstream tasks. It acts as a temporary “buffer” for the contrastive loss.
def add_projection_head(encoder):
    # Map the 2048-dimensional encoder output to a 128-dimensional space
    inputs = layers.Input(shape=(2048,))
    x = layers.Dense(512, activation="relu")(inputs)
    outputs = layers.Dense(128)(x)
    return keras.Model(inputs, outputs, name="projection_head")

projection_head = add_projection_head(encoder)

Define the Supervised Contrastive Loss in Python Keras
The loss function is what makes this “Supervised.” It calculates the similarity between all pairs in a batch and penalizes the model if samples with the same label are far apart.
I implement this using matrix multiplication for efficiency. We use a temperature parameter to control how sharply the loss penalizes negative pairs.
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=0.05, name=None):
        super().__init__(name=name)
        self.temperature = temperature

    def call(self, labels, feature_vectors):
        # Normalize the feature vectors so the dot product is cosine similarity
        feature_vectors = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute temperature-scaled pairwise similarity logits
        logits = tf.divide(
            tf.matmul(feature_vectors, tf.transpose(feature_vectors)),
            self.temperature,
        )
        # npairs_loss (from tensorflow_addons) treats samples sharing a label
        # as positives, which approximates the SupCon objective
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
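If you prefer to avoid the tensorflow_addons dependency, a standalone sketch of the SupCon formula from Khosla et al. might look like the function below. Note that `supcon_loss` is my own name for this helper, not a Keras API; it averages, for each anchor, the log-probability of its same-label positives over all other samples in the batch.

```python
import tensorflow as tf

def supcon_loss(labels, features, temperature=0.1):
    # features: [n, d] embeddings; labels: [n] integer class ids
    features = tf.math.l2_normalize(features, axis=1)
    labels = tf.reshape(tf.cast(labels, tf.int32), [-1, 1])
    n = tf.shape(features)[0]
    # Temperature-scaled pairwise cosine similarities
    logits = tf.matmul(features, features, transpose_b=True) / temperature
    # Exclude each sample's similarity with itself
    logits_mask = 1.0 - tf.eye(n)
    # Positives: other samples that share the same label
    pos_mask = tf.cast(tf.equal(labels, tf.transpose(labels)), tf.float32) * logits_mask
    # log p(j | i), normalized over all samples except i itself
    exp_logits = tf.exp(logits) * logits_mask
    log_prob = logits - tf.math.log(tf.reduce_sum(exp_logits, axis=1, keepdims=True))
    # Average over each anchor's positives (guard against anchors with none)
    pos_counts = tf.maximum(tf.reduce_sum(pos_mask, axis=1), 1.0)
    mean_log_prob_pos = tf.reduce_sum(pos_mask * log_prob, axis=1) / pos_counts
    return -tf.reduce_mean(mean_log_prob_pos)
```

Tightly clustered same-class embeddings yield a lower loss than scrambled ones, which is exactly the gradient signal we want.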
# Note: In practice, I use a custom implementation of the SupCon formula

Train the SupCon Model with Python Keras
Now we combine everything into a custom training loop or use the Model.compile API. I prefer creating a custom train_step to have full control over the two-view augmentation.
This process trains the encoder to be extremely good at distinguishing between different product categories. It takes a bit longer than standard training but yields much better results.
class SupConModel(keras.Model):
    def __init__(self, encoder, projection_head):
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head

    def train_step(self, data):
        images, labels = data
        # Create two augmented versions of the same batch
        aug_img_1 = apply_augmentation(images)
        aug_img_2 = apply_augmentation(images)
        with tf.GradientTape() as tape:
            # Get projected embeddings for both views
            p1 = self.projection_head(self.encoder(aug_img_1, training=True))
            p2 = self.projection_head(self.encoder(aug_img_2, training=True))
            # Duplicate the labels so they line up with the concatenated views
            paired_labels = tf.concat([labels, labels], axis=0)
            loss = self.compiled_loss(paired_labels, tf.concat([p1, p2], axis=0))
        # Update the weights of both sub-networks
        trainable_vars = (
            self.encoder.trainable_variables
            + self.projection_head.trainable_variables
        )
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"loss": loss}

# Wire everything together and run the contrastive pre-training phase
supcon_model = SupConModel(encoder, projection_head)
supcon_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=SupervisedContrastiveLoss(temperature=0.05),
)
supcon_model.fit(train_ds, epochs=20)

Fine-Tuning for Classification in Python Keras
Once the encoder is trained, we freeze it and add a linear classifier on top. This is where we see the real power of Supervised Contrastive Learning.
The features learned during the SupCon phase are much more robust. I often find that I need fewer epochs to reach high accuracy during this fine-tuning stage.
def build_classifier(encoder, num_classes=10):
    # Freeze the encoder weights so only the new head learns
    encoder.trainable = False
    # Create the final classification model
    inputs = layers.Input(shape=(32, 32, 3))
    features = encoder(inputs)
    outputs = layers.Dense(num_classes, activation="softmax")(features)
    return keras.Model(inputs, outputs)

classifier = build_classifier(encoder)
classifier.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
classifier.fit(train_ds, epochs=5)

Evaluate Model Performance with Python Keras
Finally, we evaluate our model on the test set. I always look at the confusion matrix to see if the model is still confusing similar-looking items.
In my experience, the clusters formed by SupCon are much tighter. This leads to a model that performs better on “out-of-distribution” data or slightly noisy images.
# Evaluate the classifier on the test dataset (apply the same normalization)
loss, accuracy = classifier.evaluate(x_test.astype("float32") / 255.0, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")

In this tutorial, I showed you how to move beyond standard classification and use Supervised Contrastive Learning in Python Keras. It is a powerful technique that can significantly boost your model’s ability to learn meaningful features.
I have found that while it requires more setup and computation time, the results in production environments are well worth the effort. You can use this approach for various image-based tasks where class separation is a challenge.
You may read:
- Knowledge Distillation for Vision Transformers in Keras
- Focal Modulation vs Self-Attention in Keras
- Image Classification Using Keras Forward-Forward Algorithm
- Implement Masked Image Modeling with Keras Autoencoders

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, and more, working for clients in the United States, Canada, the United Kingdom, Australia, and New Zealand. Check out my profile.