Implement Metric Learning for Image Similarity Search in Keras

I’ve found that standard classification isn’t always enough. Sometimes, you need to know how “close” two images are rather than just labeling them.

Metric learning allows us to train a model to map images into a multi-dimensional space where similar items sit close together. This is the core idea behind face recognition systems and the visual product recommendation engines that online retailers use to suggest similar-looking items.

In this tutorial, I will walk you through the practical ways I implement metric learning for image similarity search using Keras and TensorFlow.

Metric Learning in Keras

Metric learning is about teaching a neural network to map inputs into an embedding space where a simple distance function, such as Euclidean distance or cosine similarity, measures how similar two inputs are.

Instead of predicting a class, the model learns an embedding (a numerical vector for each image), and the distance between two embeddings reflects how similar the images are.
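To make this concrete, here is a small NumPy sketch with hand-picked 3-dimensional vectors standing in for image embeddings (the values and the names cat_1, cat_2, and car are made up for illustration). The two "cat" vectors are close under both Euclidean distance and cosine similarity, while the "car" vector is far away.

```python
import numpy as np

# Hypothetical embeddings; a real model would output e.g. 128-dimensional vectors
cat_1 = np.array([0.9, 0.1, 0.0])
cat_2 = np.array([0.8, 0.2, 0.1])
car = np.array([0.0, 0.1, 0.9])

def euclidean(a, b):
    # Straight-line distance: smaller means more similar
    return float(np.linalg.norm(a - b))

def cosine(a, b):
    # Cosine of the angle between the vectors: closer to 1 means more similar
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(f"cat_1 vs cat_2: distance={euclidean(cat_1, cat_2):.3f}, cosine={cosine(cat_1, cat_2):.3f}")
print(f"cat_1 vs car:   distance={euclidean(cat_1, car):.3f}, cosine={cosine(cat_1, car):.3f}")
```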

Method 1: Use Siamese Networks with Keras

A Siamese network consists of two identical subnetworks that share the same weights, so both images are mapped into the same embedding space and can be compared directly.

import tensorflow as tf
from tensorflow.keras import layers, Model

def build_siamese_base(input_shape):
    # This base model extracts features from the images
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, (3, 3), activation="relu")(inputs)
    x = layers.MaxPooling2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(128, activation="relu")(x)
    return Model(inputs, x)

input_shape = (128, 128, 3)
base_network = build_siamese_base(input_shape)

# Create two inputs for the twin networks
input_a = layers.Input(shape=input_shape)
input_b = layers.Input(shape=input_shape)

# Both inputs pass through the same weights
feat_a = base_network(input_a)
feat_b = base_network(input_b)

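# Calculate the Euclidean distance between the two embeddings.
# The TF ops are wrapped in a Lambda layer because newer Keras versions (Keras 3)
# do not accept raw tf.* calls on symbolic Keras tensors. The small epsilon keeps
# sqrt() away from zero, where its gradient blows up.
def euclidean_distance(vectors):
    a, b = vectors
    sum_sq = tf.reduce_sum(tf.square(a - b), axis=1, keepdims=True)
    return tf.sqrt(tf.maximum(sum_sq, 1e-7))

distance = layers.Lambda(euclidean_distance, output_shape=(1,))([feat_a, feat_b])
siamese_model = Model(inputs=[input_a, input_b], outputs=distance)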

siamese_model.summary()

You can see the output in the screenshot below.

[Screenshot: output of siamese_model.summary()]

I use this method when I want to compare two specific images to see if they belong to the same category, like verifying a digital ID photo.
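Training a Siamese network needs labeled pairs. One common approach, sketched below under the assumption that you have a NumPy array of images and integer class labels, is to pair each image with one random image of the same class (label 1) and one image of a different class (label 0). The function name make_pairs is hypothetical; the label convention matches the contrastive loss in the next section (1 = similar, 0 = different).

```python
import numpy as np

def make_pairs(images, labels, seed=0):
    """Build (image_a, image_b) pairs labeled 1 for same class and 0 for different class.

    Assumes at least two classes and at least two images per class.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    # Precompute the indices belonging to each class for fast sampling
    idx_by_class = {c: np.flatnonzero(labels == c) for c in np.unique(labels)}

    pairs_a, pairs_b, pair_labels = [], [], []
    for i, label in enumerate(labels):
        # Positive pair: a different image of the same class
        same = idx_by_class[label]
        j = rng.choice(same[same != i])
        pairs_a.append(images[i])
        pairs_b.append(images[j])
        pair_labels.append(1)

        # Negative pair: an image from any other class
        k = rng.choice(np.flatnonzero(labels != label))
        pairs_a.append(images[i])
        pairs_b.append(images[k])
        pair_labels.append(0)

    return np.stack(pairs_a), np.stack(pairs_b), np.array(pair_labels, dtype="float32")
```

With hypothetical training arrays x_train and y_train, the output would feed the model as `siamese_model.fit([pairs_a, pairs_b], pair_labels, ...)` once it has been compiled with a contrastive loss.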

Implement Contrastive Loss in Keras

Contrastive loss is a popular objective function used to train Siamese networks for image similarity.

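def contrastive_loss(y_true, y_pred, margin=1.0):
    # y_true is 1 for similar pairs and 0 for dissimilar pairs; y_pred is the distance.
    # Cast and reshape the labels so they line up with the (batch, 1) distance tensor.
    y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), tf.shape(y_pred))
    # Similar pairs are penalized by their squared distance
    sq_pred = tf.square(y_pred)
    # Dissimilar pairs are penalized only while they are closer than the margin
    margin_sq = tf.square(tf.maximum(margin - y_pred, 0.0))
    return tf.reduce_mean(y_true * sq_pred + (1 - y_true) * margin_sq)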

# Example of compiling the Keras model with this loss
siamese_model.compile(optimizer='adam', loss=contrastive_loss)

It works by minimizing the squared distance between similar pairs while pushing dissimilar pairs apart until their distance exceeds the margin. Dissimilar pairs that are already farther apart than the margin contribute no loss.
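A quick way to see the margin at work is to evaluate the same formula on a few made-up distances. The sketch below re-implements the loss in NumPy (the distances are illustrative, not model output) and returns the per-pair contribution with margin = 1.0.

```python
import numpy as np

def contrastive_terms(y_true, distance, margin=1.0):
    # Same formula as contrastive_loss above, but returns one value per pair
    y_true = np.asarray(y_true, dtype=float)
    distance = np.asarray(distance, dtype=float)
    similar_term = y_true * distance ** 2
    dissimilar_term = (1 - y_true) * np.maximum(margin - distance, 0.0) ** 2
    return similar_term + dissimilar_term

# 1 = similar pair, 0 = dissimilar pair
labels = [1, 1, 0, 0]
distances = [0.1, 0.8, 0.3, 1.5]
for y, d, t in zip(labels, distances, contrastive_terms(labels, distances)):
    print(f"label={y} distance={d:.1f} loss={t:.2f}")
```

The dissimilar pair at distance 1.5 contributes zero because it is already beyond the margin, while the dissimilar pair at distance 0.3 is penalized by (1.0 - 0.3)^2 = 0.49.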

Method 2: Triplet Loss for Image Similarity Search in Keras

Triplet loss is my preferred method when dealing with large-scale image retrieval systems for e-commerce catalogs.

class TripletLossLayer(layers.Layer):
    def __init__(self, margin=0.5, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def call(self, inputs):
        anchor, positive, negative = inputs
        # Squared Euclidean distance between anchor and positive
        pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
        # Squared Euclidean distance between anchor and negative
        neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
        # Per-triplet hinge loss: zero once the negative is `margin` farther than the positive
        loss = tf.maximum(pos_dist - neg_dist + self.margin, 0.0)
        # Register the mean loss so that compile(loss=None) has something to minimize
        self.add_loss(tf.reduce_mean(loss))
        return loss

# Defining the three inputs
in_anc = layers.Input(shape=(128, 128, 3), name="anchor")
in_pos = layers.Input(shape=(128, 128, 3), name="positive")
in_neg = layers.Input(shape=(128, 128, 3), name="negative")

# Extracting embeddings
emb_anc = base_network(in_anc)
emb_pos = base_network(in_pos)
emb_neg = base_network(in_neg)

# Adding the custom loss layer
loss_layer = TripletLossLayer(margin=0.4)([emb_anc, emb_pos, emb_neg])
triplet_model = Model(inputs=[in_anc, in_pos, in_neg], outputs=loss_layer)

You can see the output in the screenshot below.

[Screenshot: output of the triplet model code]

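It uses three inputs: an Anchor, a Positive (same class as the Anchor), and a Negative (different class). The loss pulls the Anchor toward the Positive and pushes it away from the Negative until the anchor-negative distance exceeds the anchor-positive distance by at least the margin.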
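Here is the same hinge computed by hand for a single triplet, using NumPy and made-up 2-D embeddings. Like TripletLossLayer, it uses squared Euclidean distances, with a margin of 0.4 (the value passed to the layer above).

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.4):
    # Squared Euclidean distances, as in TripletLossLayer
    pos_dist = np.sum((anchor - positive) ** 2, axis=-1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=-1)
    # Loss is zero once the negative is at least `margin` farther than the positive
    return np.maximum(pos_dist - neg_dist + margin, 0.0)

anchor = np.array([0.0, 0.0])
positive = np.array([0.3, 0.4])        # squared distance 0.25
near_negative = np.array([0.6, 0.0])   # squared distance 0.36
far_negative = np.array([1.0, 1.0])    # squared distance 2.0

print(triplet_loss(anchor, positive, near_negative))  # 0.25 - 0.36 + 0.4 = 0.29 (up to float rounding)
print(triplet_loss(anchor, positive, far_negative))   # 0.25 - 2.0 + 0.4 < 0, clipped to 0
```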

Build the Embedding Pipeline in Keras

Once the model is trained, you only need the base subnetwork to generate embeddings for your search index.

# The encoder is the shared base_network; training the triplet model updated its weights
trained_encoder = base_network

def get_image_embedding(img_path):
    # Load the image, resize it to the model's input size, and scale pixels to [0, 1]
    img = tf.keras.utils.load_img(img_path, target_size=(128, 128))
    img_array = tf.keras.utils.img_to_array(img) / 255.0
    img_array = tf.expand_dims(img_array, axis=0)  # add a batch dimension

    # Predict the vector representation, shape (1, 128)
    embedding = trained_encoder.predict(img_array, verbose=0)
    return embedding

test_embedding = get_image_embedding('test_image.jpg')
print(f"Embedding shape: {test_embedding.shape}")

You can see the output in the screenshot below.

[Screenshot: printed embedding shape]

In a real-world scenario, such as a real estate app in Los Angeles, you would pre-calculate these embeddings for every property photo, so that a visual search only needs one forward pass for the query image plus a nearest-neighbor lookup.
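A minimal in-memory index for those pre-calculated vectors could look like the sketch below. It assumes the embeddings come from trained_encoder in batches and uses random vectors as stand-ins; the class name EmbeddingIndex and its methods are hypothetical. Vectors are L2-normalized once at insertion time, so each query is a single matrix-vector product.

```python
import numpy as np

class EmbeddingIndex:
    """Hypothetical in-memory index storing L2-normalized embeddings with item IDs."""

    def __init__(self, dim):
        self.vectors = np.empty((0, dim), dtype="float32")
        self.ids = []

    @staticmethod
    def _normalize(x):
        # Divide each row by its length; the epsilon guards against zero vectors
        x = np.asarray(x, dtype="float32")
        return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)

    def add(self, ids, embeddings):
        # Append a batch of embeddings, e.g. trained_encoder.predict(batch_of_images)
        self.vectors = np.vstack([self.vectors, self._normalize(embeddings)])
        self.ids.extend(ids)

    def search(self, query, top_k=5):
        # Cosine similarity of the query against every stored vector
        q = self._normalize(np.atleast_2d(query))[0]
        scores = self.vectors @ q
        best = np.argsort(-scores)[:top_k]
        return [(self.ids[i], float(scores[i])) for i in best]

# Stand-in data: 1,000 random 128-d "property photo" embeddings
rng = np.random.default_rng(42)
index = EmbeddingIndex(dim=128)
index.add([f"photo_{i}" for i in range(1000)], rng.normal(size=(1000, 128)))
print(index.search(rng.normal(size=128), top_k=3))
```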

Visualize Embeddings with Keras and PCA

To verify if the metric learning model is actually working, I often visualize the embeddings in a 2D or 3D space.

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np

# Let's assume 'all_embeddings' is a numpy array of vectors from your dataset
def plot_embeddings(all_embeddings, labels):
    pca = PCA(n_components=2)
    reduced_vecs = pca.fit_transform(all_embeddings)
    
    plt.figure(figsize=(10, 7))
    plt.scatter(reduced_vecs[:, 0], reduced_vecs[:, 1], c=labels, cmap='viridis')
    plt.colorbar()
    plt.title("Visualizing Image Similarity Clusters")
    plt.show()

I use Principal Component Analysis (PCA) to reduce the dimensionality of the Keras-generated vectors so I can see if similar images cluster together.
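If you want to see what `PCA(n_components=2).fit_transform` does without scikit-learn, the projection can be reproduced in a few lines of NumPy: center the embeddings, take the SVD, and project onto the top two principal directions. The synthetic clusters below are stand-ins for real embeddings, and pca_2d is a hypothetical helper. Component signs are arbitrary, so the result may be mirrored relative to scikit-learn's output.

```python
import numpy as np

def pca_2d(embeddings):
    # Center each dimension, as PCA requires
    x = np.asarray(embeddings, dtype=float)
    x_centered = x - x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance
    _, s, vt = np.linalg.svd(x_centered, full_matrices=False)
    projected = x_centered @ vt[:2].T
    # Fraction of total variance captured by each of the two components
    explained = (s[:2] ** 2) / np.sum(s ** 2)
    return projected, explained

# Two synthetic 128-d clusters standing in for embeddings of two classes
rng = np.random.default_rng(0)
class_a = rng.normal(loc=0.0, scale=0.1, size=(50, 128))
class_b = rng.normal(loc=1.0, scale=0.1, size=(50, 128))
coords, ratio = pca_2d(np.vstack([class_a, class_b]))
print(coords.shape, ratio)
```

The explained-variance ratio is a useful companion to the plot: if the first two components capture only a small share of the variance, a 2D scatter can hide structure that the full embedding space still has.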

Perform the Image Similarity Search

After generating embeddings for your entire database, you can perform a search by calculating the distance between a query image and your stored vectors.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def find_most_similar(query_embedding, database_embeddings, top_k=5):
    # Calculate similarity scores
    scores = cosine_similarity(query_embedding, database_embeddings)
    
    # Get the indices of the highest scores
    best_matches = np.argsort(scores[0])[::-1][:top_k]
    return best_matches

# Example usage
# matches = find_most_similar(new_img_vector, my_indexed_vectors)

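I typically use cosine similarity (higher means more similar) or L2 distance (lower means more similar) to find the most relevant matches. Note that find_most_similar sorts scores in descending order, so it expects a similarity score rather than a distance.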
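The two choices agree when the embeddings are L2-normalized: for unit vectors a and b, ||a - b||^2 = 2 - 2 cos(a, b), so sorting by ascending L2 distance gives the same order as sorting by descending cosine similarity. The NumPy check below confirms this on random stand-in vectors.

```python
import numpy as np

rng = np.random.default_rng(7)
database = rng.normal(size=(200, 128))
query = rng.normal(size=(1, 128))

# L2-normalize so every vector has length 1
database /= np.linalg.norm(database, axis=1, keepdims=True)
query /= np.linalg.norm(query, axis=1, keepdims=True)

cosine_scores = (database @ query.T).ravel()             # higher = more similar
l2_distances = np.linalg.norm(database - query, axis=1)  # lower = more similar

top_by_cosine = np.argsort(-cosine_scores)[:5]
top_by_l2 = np.argsort(l2_distances)[:5]
print(top_by_cosine, top_by_l2)

# The identity ||a - b||^2 = 2 - 2 cos(a, b) holds for unit vectors
print(np.allclose(l2_distances ** 2, 2 - 2 * cosine_scores))
```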

Training Tips for Metric Learning in Keras

Training these models can be tricky because the choice of pairs or triplets (mining) significantly impacts performance.

I recommend starting with “Easy Triplets” to let the model converge early and then moving to “Semi-Hard Triplets” to refine the accuracy.
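The triplet categories depend on where the anchor-negative distance falls relative to the anchor-positive distance and the margin. The hypothetical helper below labels a triplet from its two distances, using the same squared distances and margin (0.4) as TripletLossLayer. Strictly easy triplets already produce zero loss and therefore no gradient, which is why semi-hard triplets (negative farther than the positive, but still inside the margin) are the ones that keep refining the embedding.

```python
def triplet_category(pos_dist, neg_dist, margin=0.4):
    """Classify a triplet from its anchor-positive and anchor-negative distances."""
    if neg_dist > pos_dist + margin:
        return "easy"       # already satisfies the margin, loss is 0
    if neg_dist > pos_dist:
        return "semi-hard"  # negative is farther, but still inside the margin
    return "hard"           # negative is as close as, or closer than, the positive

for pos, neg in [(0.25, 2.0), (0.25, 0.36), (0.25, 0.10)]:
    print(pos, neg, triplet_category(pos, neg))
```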

# Using a higher batch size helps the model see more relationships
# I often use 64 or 128 for metric learning tasks in Keras
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
triplet_model.compile(optimizer=optimizer, loss=None) 
# Note: loss is None because it's calculated inside the custom layer

In this tutorial, I showed you several ways to implement metric learning for image similarity search using Keras.

Whether you are building a tool to find similar fashion items or a system to detect duplicate documents, these techniques let you compare images by learned similarity instead of by a fixed set of class labels.
