When I first started working with self-supervised learning, I found that standard contrastive methods often relied too heavily on simple data augmentations to create “positive” pairs.
I found that Nearest Neighbor Contrastive Learning (NNCLR) addresses this by using a sample's nearest neighbor from a support set as its positive. This provides more semantic variety than augmentations of a single image can.
In this guide, I will show you how I implement NNCLR using Keras to build robust models that can learn from unlabeled data effectively.
What is NNCLR in Keras Self-Supervised Learning?
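NNCLR is a self-supervised method that builds on SimCLR. SimCLR treats another augmented view of the same image as the only positive. NNCLR instead looks up the nearest neighbor of a sample's embedding in a support set of recently seen representations and uses that neighbor as the positive.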
This approach allows the model to see different variations of a “concept” rather than just a flipped or cropped version of the same image.
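To make the difference concrete, here is a minimal NumPy sketch of the nearest-neighbor lookup on made-up 2-D embeddings. It assumes cosine similarity, as in the Keras code later in this guide. The `support_set` and `queries` values are purely illustrative. In SimCLR the positive would simply be the other augmented view; here it is whichever stored embedding points in the most similar direction.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    # Scale each row to unit length so dot products become cosine similarities
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def nearest_neighbor(queries, support_set):
    # Cosine similarity between every query and every stored embedding
    sim = l2_normalize(queries) @ l2_normalize(support_set).T
    idx = np.argmax(sim, axis=1)
    # Return the stored (unnormalized) embeddings and their indices
    return support_set[idx], idx


# Hypothetical support set of three "recently seen" embeddings
support_set = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
queries = np.array([[0.1, 2.0], [-3.0, 0.2]])
neighbors, idx = nearest_neighbor(queries, support_set)
print(idx)  # [1 2]
```

Only direction matters here. The query `[-3.0, 0.2]` still maps to `[-1.0, 0.0]` even though its magnitude is much larger.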
Set Up the Environment for Keras NNCLR
Before we dive into the architecture, I import the libraries the project needs: TensorFlow, Keras, and NumPy.
I use the following code to set up the project and fix the random seeds. The Keras `layers` module also provides the preprocessing layers used for augmentation later.
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Setting seeds for reproducibility in Keras
tf.random.set_seed(42)
np.random.seed(42)
```

Prepare the Dataset for Self-Supervised Learning
For this tutorial, I use CIFAR-10 (50,000 training and 10,000 test images, 32×32 pixels, 10 classes). It is small enough to train on quickly while still being a realistic object recognition benchmark.
I scale pixel values to the [0, 1] range, which generally helps the network converge during the contrastive pre-training phase.
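As a quick check of what this scaling does, the NumPy snippet below applies the same `astype("float32") / 255.0` conversion to every possible 8-bit pixel value. The test array is synthetic. It confirms that values land in [0, 1] and that the mapping is reversible, so no information is lost.

```python
import numpy as np


def scale_images(x):
    # Same conversion as in prepare_dataset: uint8 [0, 255] -> float32 [0, 1]
    return x.astype("float32") / 255.0


pixels = np.arange(256, dtype=np.uint8)
scaled = scale_images(pixels)
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0
```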
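```python
def prepare_dataset():
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    # Scale pixel values from [0, 255] to [0, 1]
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    # Pre-training ignores the labels, but the linear probe at the end needs them
    return (x_train, y_train), (x_test, y_test)

(train_images, train_labels), (test_images, test_labels) = prepare_dataset()
```

Implement Data Augmentation in Keras NNCLR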
Data augmentation is the backbone of contrastive learning because it creates the “views” that the model compares.
My pipeline applies four random transformations so that the model learns to be invariant to them: horizontal flips, small rotations, zoom (a rough stand-in for random cropping), and contrast changes (a limited form of color jittering).
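To show what "two views of the same image" means independently of Keras, here is a NumPy sketch that builds two random views of one image. It uses a horizontal flip, a CIFAR-style pad-and-crop (standing in for random cropping), and brightness scaling (a simple form of color jitter). These are illustrative stand-ins, not the Keras preprocessing layers used in this guide. The 4-pixel padding and the ±20% brightness range are assumptions.

```python
import numpy as np


def random_view(image, rng, pad=4, max_brightness_change=0.2):
    # Random horizontal flip with probability 0.5
    if rng.random() < 0.5:
        image = image[:, ::-1, :]
    # Pad by reflection, then crop back to the original size at a random offset
    h, w, _ = image.shape
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    image = padded[top:top + h, left:left + w, :]
    # Scale brightness by a random factor and keep pixels in [0, 1]
    factor = rng.uniform(1 - max_brightness_change, 1 + max_brightness_change)
    return np.clip(image * factor, 0.0, 1.0)


rng = np.random.default_rng(0)
image = rng.random((32, 32, 3), dtype=np.float32)  # placeholder for one CIFAR-10 image
view1 = random_view(image, rng)
view2 = random_view(image, rng)
print(view1.shape, view2.shape)  # (32, 32, 3) (32, 32, 3)
```

Both views come from the same image but differ in flip, crop offset, and brightness. This gap is what the contrastive loss teaches the model to ignore.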
```python
def get_augmentation_model():
    data_augmentation = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomContrast(0.1),
        layers.RandomZoom(0.1),
    ])
    return data_augmentation

augmentation_model = get_augmentation_model()
```

Build the Encoder Architecture with Keras
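The encoder is the part of the model that extracts features from the input images. For this tutorial, I keep it to a small two-layer convolutional network so that it trains quickly. For larger projects, a ResNet-like backbone is a common and reliable choice.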
I keep the encoder separate from the projection head so that I can easily use it for downstream tasks later.
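Since the projection head expects a 128-dimensional input, it is worth tracing the encoder's shapes by hand. This small calculation assumes the 32×32×3 CIFAR-10 input and Keras defaults: `padding="valid"` and stride 1 for `Conv2D`, and strides equal to the pool size for `MaxPooling2D`.

```python
def conv_valid(size, kernel, stride=1):
    # Output length of a "valid" (unpadded) convolution or pooling window
    return (size - kernel) // stride + 1


def conv_params(kernel, in_channels, filters):
    # kernel * kernel * in_channels weights per filter, plus one bias per filter
    return kernel * kernel * in_channels * filters + filters


size = 32
size = conv_valid(size, 3)            # Conv2D(64, 3x3):   32 -> 30
size = conv_valid(size, 2, stride=2)  # MaxPooling2D(2x2): 30 -> 15
size = conv_valid(size, 3)            # Conv2D(128, 3x3):  15 -> 13
feature_dim = 128                     # GlobalAveragePooling2D keeps one value per channel
params = conv_params(3, 3, 64) + conv_params(3, 64, 128)
print(size, feature_dim, params)  # 13 128 75648
```

The 128 is what `layers.Input(shape=(128,))` in the projection head has to match. The two convolutions together hold 75,648 trainable parameters.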
```python
def build_encoder():
    # CIFAR-10 images are 32x32 pixels with 3 color channels
    inputs = layers.Input(shape=(32, 32, 3))
    x = layers.Conv2D(64, (3, 3), activation="relu")(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation="relu")(x)
    # Average over spatial positions: one 128-dimensional vector per image
    x = layers.GlobalAveragePooling2D()(x)
    return keras.Model(inputs, x, name="encoder")

encoder = build_encoder()
```

Design the NNCLR Projection Head in Keras
The projection head maps the high-dimensional features to a smaller latent space where the contrastive loss is calculated.
I use a non-linear MLP (Multi-Layer Perceptron) with Batch Normalization to stabilize the learning process in this lower-dimensional space.
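Here is the projection head's forward pass written out in NumPy. It shows what the Dense (ReLU), BatchNormalization, Dense stack does to a batch. It is a sketch with random weights. It assumes BatchNormalization in training mode (batch statistics, gamma = 1, beta = 0) and an epsilon of 1e-3 (the Keras default). It also counts parameters the way `model.summary()` reports its total.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, in_dim, hidden, out_dim = 8, 128, 128, 64

# Random weights standing in for trained ones
w1, b1 = rng.normal(size=(in_dim, hidden)) * 0.1, np.zeros(hidden)
w2, b2 = rng.normal(size=(hidden, out_dim)) * 0.1, np.zeros(out_dim)


def projection_head(x, eps=1e-3):
    h = np.maximum(x @ w1 + b1, 0.0)     # Dense(128, relu)
    mean, var = h.mean(axis=0), h.var(axis=0)
    h = (h - mean) / np.sqrt(var + eps)  # BatchNormalization (training mode)
    return h @ w2 + b2                   # Dense(64)


features = rng.normal(size=(batch, in_dim))  # stand-in for encoder outputs
z = projection_head(features)
print(z.shape)  # (8, 64)

# Dense: weights + biases; BatchNormalization: gamma, beta, moving mean, moving variance
params = (in_dim * hidden + hidden) + 4 * hidden + (hidden * out_dim + out_dim)
print(params)  # 25280
```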
```python
def build_projection_head():
    inputs = layers.Input(shape=(128,))
    x = layers.Dense(128, activation="relu")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(64)(x)  # Output dimension for contrastive task
    return keras.Model(inputs, x, name="projection_head")

projection_head = build_projection_head()
```

Create the NNCLR Support Set Logic
The support set is a memory bank that stores recent embeddings, allowing the model to find the “nearest neighbor” for a given sample.
I implement it as a fixed-size, non-trainable `tf.Variable` inside a custom Keras model. After every training step, the newest embeddings replace the oldest ones so that the neighbors stay fresh.
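The update rule itself is just a first-in, first-out buffer. This NumPy sketch mirrors the idea: the newest embeddings go to the front and the same number of oldest rows fall off the end, so the support set keeps a fixed size. The guard for batches larger than the queue is an added assumption for this sketch.

```python
import numpy as np


def update_queue(queue, new_embeddings):
    size = queue.shape[0]
    # Assumption for this sketch: a batch larger than the queue keeps only its first `size` rows
    new_embeddings = new_embeddings[:size]
    n = new_embeddings.shape[0]
    # Newest rows first, then the oldest rows that still fit
    return np.concatenate([new_embeddings, queue[: size - n]], axis=0)


queue = np.arange(10, dtype=float).reshape(5, 2)  # 5 stored embeddings of dimension 2
batch = -np.ones((2, 2))                          # 2 new embeddings
updated = update_queue(queue, batch)
print(updated.shape)  # (5, 2)
```

Slicing with `size - n` rather than `-n` also avoids an empty-slice pitfall when `n` is 0, because `queue[:-0]` is an empty array in Python.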
```python
class NNCLR(keras.Model):
    def __init__(self, encoder, projection_head, temperature=0.1, queue_size=1000):
        super(NNCLR, self).__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.temperature = temperature
        self.queue_size = queue_size
        # Support set: a FIFO queue of recent 64-dimensional projections,
        # initialized with random vectors and never updated by gradients
        self.feature_queue = tf.Variable(
            tf.random.normal((queue_size, 64)), trainable=False
        )

    def nearest_neighbor(self, queries):
        # Normalizing for cosine similarity
        queries = tf.math.l2_normalize(queries, axis=1)
        queue = tf.math.l2_normalize(self.feature_queue, axis=1)
        # Matrix multiplication to find similarity
        sim = tf.matmul(queries, queue, transpose_b=True)
        nn_idx = tf.argmax(sim, axis=1)
        # The neighbors come from a non-trainable variable, so no gradient
        # flows back through this lookup
        return tf.gather(self.feature_queue, nn_idx)
```

Define the Contrastive Loss Function in Keras
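The loss pulls the nearest neighbor of each first-view embedding toward the second view of the same image. At the same time, it pushes that neighbor away from the second-view embeddings of the other images in the batch.
I use cross-entropy over the temperature-scaled cosine similarities. This is a simplified, one-directional form of the NT-Xent (normalized temperature-scaled cross-entropy) loss used in SimCLR.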
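For readers who want to check the arithmetic, here is the same one-directional loss written in NumPy. It is a sketch for inspection only and computes no gradients. It assumes L2-normalized projections and a temperature of 0.1, as in the Keras function. Each row i of the logits matrix is treated as a classification problem whose correct class is column i.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def contrastive_loss(projections_1, projections_2, temperature=0.1):
    p1 = l2_normalize(projections_1)
    p2 = l2_normalize(projections_2)
    # logits[i, j] = cos(p1[i], p2[j]) / temperature
    logits = p1 @ p2.T / temperature
    # Numerically stable log-softmax over each row
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy with the diagonal entries as the correct "classes"
    idx = np.arange(p1.shape[0])
    return float(-log_probs[idx, idx].mean())


aligned = np.eye(4)                 # each positive pair matches exactly
shuffled = np.eye(4)[[1, 0, 3, 2]]  # positives swapped between rows
print(contrastive_loss(aligned, aligned))   # ~1.36e-4, i.e. log(1 + 3 * exp(-10))
print(contrastive_loss(aligned, shuffled))  # much larger (about 10.0)
```

SimCLR-style implementations often symmetrize this loss. They compute it a second time with the two views swapped and average the two results. This guide keeps the single direction for simplicity.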
```python
def contrastive_loss(projections_1, projections_2, temperature=0.1):
    projections_1 = tf.math.l2_normalize(projections_1, axis=1)
    projections_2 = tf.math.l2_normalize(projections_2, axis=1)
    # Cosine similarities between every pair in the batch, scaled by temperature
    logits = tf.divide(
        tf.matmul(projections_1, projections_2, transpose_b=True),
        temperature
    )
    # Row i's positive is column i (the other view of the same image)
    batch_size = tf.shape(projections_1)[0]
    labels = tf.range(batch_size)
    # Average the per-sample losses into a single scalar
    return tf.reduce_mean(
        keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
    )
```

Implement the Custom Train Step for NNCLR
To train the NNCLR model, I override the train_step method in Keras to handle the dual-view input and the support set update.
This allows me to use the standard model.fit() API while performing the complex nearest neighbor lookups under the hood.
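Because the support set is updated inside every train step, it helps to know how many steps `fit()` actually runs. This short calculation uses the values from this guide: the 50,000-image CIFAR-10 training split, `batch_size=128`, `queue_size=1000`, and 20 epochs. It shows how quickly the randomly initialized queue entries are replaced by real embeddings.

```python
import math


def training_schedule(num_examples, batch_size, queue_size, epochs):
    # fit() keeps the final partial batch, so round up
    steps_per_epoch = math.ceil(num_examples / batch_size)
    last_batch_size = num_examples - (steps_per_epoch - 1) * batch_size
    # Full batches needed before every random initial queue entry has been pushed out
    steps_to_flush_queue = math.ceil(queue_size / batch_size)
    return {
        "steps_per_epoch": steps_per_epoch,
        "last_batch_size": last_batch_size,
        "steps_to_flush_queue": steps_to_flush_queue,
        "total_steps": steps_per_epoch * epochs,
    }


print(training_schedule(50_000, 128, 1000, 20))
# {'steps_per_epoch': 391, 'last_batch_size': 80, 'steps_to_flush_queue': 8, 'total_steps': 7820}
```

After the first 8 steps, every neighbor comes from a real (if still early) embedding. Before that, some lookups can return the random vectors the queue was initialized with.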
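```python
def train_step(self, data):
    # fit(train_images) passes the image batch directly; unwrap it if it arrives as a tuple
    images = data[0] if isinstance(data, tuple) else data
    # Create two augmented views
    view1 = augmentation_model(images, training=True)
    view2 = augmentation_model(images, training=True)
    with tf.GradientTape() as tape:
        # training=True so BatchNormalization in the projection head uses batch statistics
        z1 = self.projection_head(self.encoder(view1, training=True), training=True)
        z2 = self.projection_head(self.encoder(view2, training=True), training=True)
        # Find nearest neighbor for view 1 from the support set
        nn_z1 = self.nearest_neighbor(z1)
        loss = contrastive_loss(nn_z1, z2, self.temperature)
    # Gradient descent update (no @tf.function needed: fit() already compiles train_step)
    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    # Update the queue/support set with the newest view-1 projections
    self.update_queue(z1)
    return {"loss": loss}

def update_queue(self, new_embeddings):
    # Prepend the newest embeddings and drop the same number of oldest ones
    batch_size = tf.shape(new_embeddings)[0]
    self.feature_queue.assign(
        tf.concat([new_embeddings, self.feature_queue[: self.queue_size - batch_size]], axis=0)
    )

# Attach both functions to the NNCLR class defined earlier as methods
NNCLR.train_step = train_step
NNCLR.update_queue = update_queue
```

Compile and Train the Keras Model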
Once the custom logic is in place, I compile the model with an optimizer like Adam and start the training process.
I monitor the loss as training runs. A steadily decreasing loss suggests that the model is learning to match each image with its positive rather than with the other images in the batch.
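A useful reference point when watching the loss: if every logit in a row is equal, the cross-entropy is exactly log(batch_size). That is the loss of guessing the positive uniformly at random. The sketch below checks this with NumPy for a batch of 128. It is a sanity baseline under that assumption, not a prediction of the starting loss, which depends on how similar the initial embeddings happen to be.

```python
import numpy as np


def diagonal_cross_entropy(logits):
    # Mean cross-entropy where row i's correct class is column i
    m = logits.max(axis=1, keepdims=True)
    log_sum_exp = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return float(np.mean(log_sum_exp - np.diag(logits)))


batch_size = 128
uniform_loss = diagonal_cross_entropy(np.zeros((batch_size, batch_size)))
print(round(uniform_loss, 4))  # 4.852, i.e. log(128)
```

A loss that stays near this value suggests the model cannot yet tell the positive apart from the other 127 candidates. A steady drop below it is a reasonable sign of learning.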
```python
nnclr_model = NNCLR(encoder, projection_head)
nnclr_model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001))
# Training on the unlabeled images
nnclr_model.fit(train_images, epochs=20, batch_size=128)
```

Evaluate the Learned Representations
After pre-training, I freeze the encoder weights and attach a simple linear classifier to see how well it performs on a labeled subset.
This “linear probe” is the standard way I verify if the self-supervised features are actually useful for real-world classification.
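A linear probe is often run on only a small labeled subset. The NumPy sketch below picks the same number of examples from each class. The `labeled_subset_indices` helper is hypothetical, not a Keras API. It flattens the labels first because CIFAR-10 labels from `load_data()` have shape `(N, 1)`.

```python
import numpy as np


def labeled_subset_indices(labels, per_class, seed=0):
    # Hypothetical helper: sample `per_class` indices from every class without replacement
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).reshape(-1)  # (N, 1) -> (N,)
    chosen = [
        rng.choice(np.flatnonzero(labels == c), size=per_class, replace=False)
        for c in np.unique(labels)
    ]
    return np.sort(np.concatenate(chosen))


# Synthetic labels: 10 classes with 50 examples each, shaped like CIFAR-10 labels
labels = np.repeat(np.arange(10), 50).reshape(-1, 1)
subset = labeled_subset_indices(labels, per_class=5)
print(len(subset), np.bincount(labels[subset, 0]))  # 50 [5 5 5 5 5 5 5 5 5 5]
```

Use the resulting indices to slice both the image and label arrays before passing them to the linear classifier.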
```python
def evaluate_model(encoder, train_images, train_labels):
    # Freeze the pre-trained encoder so only the linear layer is trained
    encoder.trainable = False
    inputs = layers.Input(shape=(32, 32, 3))
    features = encoder(inputs, training=False)
    # A single Dense layer on top of the frozen features: the "linear probe"
    outputs = layers.Dense(10, activation="softmax")(features)
    classifier = keras.Model(inputs, outputs)
    classifier.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    classifier.fit(train_images, train_labels, epochs=5, batch_size=128)
    return classifier

# classifier_model = evaluate_model(encoder, train_images, train_labels)
```

I executed the above example code and added the screenshot below.

In this tutorial, I showed you how I implemented the NNCLR framework in Keras for self-supervised learning. I covered everything from data augmentation to a custom training step with a support set.
By using nearest neighbors as positive samples, NNCLR exposes the model to more varied positives than augmentation-only contrastive methods such as SimCLR. I hope you find this implementation useful in your own Python Keras projects.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.