Explore Vision Transformer (ViT) Representations in Keras

As a Keras developer who has spent the last four years building computer vision models, I have always been fascinated by how Vision Transformers (ViT) “see” the world compared to traditional CNNs.

When I first transitioned from ResNet to ViT, I struggled to understand how these global self-attention mechanisms actually processed my image data.

In this tutorial, I will show you exactly how I investigate Vision Transformer representations in Keras to ensure my models are learning the right features.

Extract Patch Embeddings in Keras

The first step in any Vision Transformer pipeline is breaking an image into patches and projecting them into an embedding space.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

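def get_patch_embeddings(image, patch_size=16, embed_dim=768):
    # A Conv2D whose kernel size equals its stride splits the image into
    # non-overlapping patches and linearly projects each one in a single step.
    # Note: the layer is created here, so its weights are random (untrained).
    image_size = image.shape[1]  # assumes a square input (height == width)
    num_patches = (image_size // patch_size) ** 2

    # Create the patch encoder layer
    projection = layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
                               strides=patch_size, padding="valid")

    # (batch, 14, 14, embed_dim) for a 224x224 input with 16x16 patches
    patches = projection(image)
    # Flatten the patch grid into a sequence: (batch, num_patches, embed_dim)
    reshaped_patches = tf.reshape(patches, (-1, num_patches, embed_dim))

    return reshaped_patches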

# Example: a random 224x224 RGB image with a batch dimension of 1
sample_img = np.random.normal(size=(1, 224, 224, 3)).astype('float32')
embeddings = get_patch_embeddings(sample_img)
print(f"Embedding shape: {embeddings.shape}")  # Embedding shape: (1, 196, 768)

You can refer to the screenshot below to see the output.

[Screenshot: Explore Vision Transformer (ViT) Representations in Keras]

I use this check to verify that an image, such as a satellite view of a Chicago suburb, is split into the expected grid of patches (14 x 14 = 196 for a 224x224 input with 16x16 patches), each with the expected embedding size, before it enters the transformer layers.
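The Conv2D trick works because a convolution whose kernel size and stride both equal the patch size applies one shared linear map to every non-overlapping patch. The sketch below checks that under small, assumed sizes (a random 64x64 image and a 32-dimensional embedding, chosen only to keep it fast): it cuts the patches out explicitly with tf.image.extract_patches, flattens them, and multiplies by the same Conv2D kernel. The two results should agree up to floating-point error.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

patch_size, embed_dim = 16, 32            # assumed small sizes for a quick check
rng = np.random.default_rng(0)
image = rng.normal(size=(1, 64, 64, 3)).astype("float32")

# Approach 1: strided Conv2D, the same idea as get_patch_embeddings
conv = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size, padding="valid")
conv_out = np.asarray(conv(image)).reshape(1, -1, embed_dim)          # (1, 16, 32)

# Approach 2: extract non-overlapping patches, flatten each one, apply a linear map.
# extract_patches flattens every patch in (row, col, channel) order, which matches
# the layout of the Conv2D kernel (kh, kw, in_channels, out_channels).
patches = tf.image.extract_patches(
    images=image,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
patches = np.asarray(patches).reshape(1, -1, patch_size * patch_size * 3)  # (1, 16, 768)

kernel, bias = conv.get_weights()
linear_out = patches @ kernel.reshape(-1, embed_dim) + bias            # (1, 16, 32)

max_diff = float(np.max(np.abs(conv_out - linear_out)))
print(f"Max difference between the two approaches: {max_diff:.2e}")
```

In other words, patch embedding is just a Dense layer applied to each flattened patch; the Conv2D form simply does the cutting and the projection in one call.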

Visualize Self-Attention Maps in Keras

Visualizing attention maps is one of the most direct ways to see which parts of an image the ViT focuses on when making a classification decision.

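def visualize_attention_maps(model, image, layer_name="multi_head_attention_layer"):
    # Get the self-attention layer by name (the name depends on how your model was built)
    attention_layer = model.get_layer(layer_name)

    # A layers.MultiHeadAttention layer only exposes its scores if it was called with
    # return_attention_scores=True when the model was built; its output is then
    # (attention_output, attention_scores). Index 1 selects the scores.
    intermediate_model = keras.Model(inputs=model.input, outputs=attention_layer.output[1])

    # Attention scores have shape (batch, num_heads, num_queries, num_keys)
    attention_scores = intermediate_model.predict(image, verbose=0)

    # Average across all heads for a global view: (batch, num_queries, num_keys)
    avg_attention = np.mean(attention_scores, axis=1)
    return avg_attention

# Minimal stand-in model for the demo: 8 tokens with 16-dim embeddings
tokens = keras.Input(shape=(8, 16))
attn_out, attn_scores = layers.MultiHeadAttention(
    num_heads=2, key_dim=16, name="multi_head_attention_layer"
)(tokens, tokens, return_attention_scores=True)
model = keras.Model(tokens, layers.GlobalAveragePooling1D()(attn_out))

sample_img = np.random.rand(1, 8, 16).astype("float32")  # 8 tokens of width 16
heatmap = visualize_attention_maps(model, sample_img)
print(heatmap.shape)  # (1, 8, 8)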

You can refer to the screenshot below to see the output.

[Screenshot: Explore Vision Transformer (ViT) Representations Keras]

I frequently use this technique to debug why a model might be misidentifying a California redwood tree or a New York City taxi by looking at the heatmaps generated by the heads.
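For a ViT whose first token is a CLS token, a common way to turn head-averaged attention (such as the avg_attention array above) into an image heatmap is to take the CLS row, drop the CLS column, reshape the remaining 196 values onto the 14x14 patch grid, and upsample to 224x224. The sketch below assumes that token layout (CLS first, patches in row-major order) and uses random softmax-normalized scores as a stand-in for real attention; cls_attention_heatmap is a hypothetical helper name.

```python
import numpy as np
import tensorflow as tf

def cls_attention_heatmap(avg_attention, image_size=224, patch_size=16):
    """Map head-averaged attention (batch, 1 + N, 1 + N) to heatmaps (batch, H, W).

    Assumes token 0 is the CLS token and tokens 1..N are patches in row-major order.
    """
    grid = image_size // patch_size                       # 14 for 224 / 16
    cls_to_patches = avg_attention[:, 0, 1:]              # CLS query -> each patch key
    maps = cls_to_patches.reshape(-1, grid, grid, 1)      # back onto the patch grid

    # Rescale each map to [0, 1] so it can be overlaid on the image
    lo = maps.min(axis=(1, 2, 3), keepdims=True)
    hi = maps.max(axis=(1, 2, 3), keepdims=True)
    maps = (maps - lo) / np.maximum(hi - lo, 1e-12)

    # Upsample the coarse patch grid to the input resolution
    heatmaps = tf.image.resize(maps.astype("float32"), (image_size, image_size),
                               method="bilinear")
    return np.asarray(heatmaps)[..., 0]

# Stand-in attention: random rows normalized with a softmax, 196 patches + 1 CLS
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 197, 197))
attention = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

heatmap = cls_attention_heatmap(attention)
print(heatmap.shape)  # (1, 224, 224)
```

With a real model, overlaying this map on the input image (for example with a semi-transparent colormap) shows which regions the CLS token drew on for its prediction.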

Analyze Positional Encodings in Keras

Since Transformers don’t inherently understand the order of patches, positional encodings are added to give the model a sense of “where” things are.
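A quick way to see this is to feed a Keras MultiHeadAttention layer the same set of patch embeddings twice, once in the original order and once shuffled, with no positional information added. The outputs come back shuffled in exactly the same way, so the layer on its own cannot tell which patch came from where. The sizes below (6 tokens, 16 dimensions, 2 heads) are arbitrary assumptions for the demo.

```python
import numpy as np
from tensorflow.keras import layers

rng = np.random.default_rng(0)
# Assumed sizes: 6 patch embeddings of width 16, with NO positional encoding added
tokens = rng.normal(size=(1, 6, 16)).astype("float32")
perm = rng.permutation(6)                         # a random shuffle of the patch order

mha = layers.MultiHeadAttention(num_heads=2, key_dim=8)
out = np.asarray(mha(tokens, tokens))                              # original order
out_shuffled = np.asarray(mha(tokens[:, perm], tokens[:, perm]))   # shuffled order

# Self-attention is permutation-equivariant: shuffling the inputs only shuffles
# the outputs in the same way, so the layer cannot tell where a patch came from
max_diff = float(np.max(np.abs(out[:, perm] - out_shuffled)))
print(f"Max difference: {max_diff:.2e}")
```

Adding a distinct positional vector to each patch slot breaks this symmetry, which is exactly the job of the positional embeddings analyzed next.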

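def analyze_positional_embeddings(model, layer_name="pos_embedding"):
    # Extract the positional embedding weight matrix. Assumes the layer (e.g. a
    # layers.Embedding) stores it as its first weight: (num_positions, embed_dim)
    pos_weights = model.get_layer(layer_name).get_weights()[0]

    # Cosine similarity between positions: scale each row to unit length,
    # then take all pairwise dot products
    norm = np.linalg.norm(pos_weights, axis=1, keepdims=True)
    unit_weights = pos_weights / np.maximum(norm, 1e-12)
    similarity = np.dot(unit_weights, unit_weights.T)

    return similarity  # (num_positions, num_positions)

# Stand-in model for the demo (replace demo_model with your trained ViT):
# 197 positions (196 patches + 1 CLS token), 768 dims
positions = keras.Input(shape=(197,), dtype="int32")
demo_model = keras.Model(
    positions, layers.Embedding(197, 768, name="pos_embedding")(positions))
similarity_matrix = analyze_positional_embeddings(demo_model)
print(similarity_matrix.shape)  # (197, 197)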

You can refer to the screenshot below to see the output.

[Screenshot: Keras Explore Vision Transformer (ViT) Representations]

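I analyze these encodings to check that neighboring patches have similar positional embeddings, a sign that the model has learned the 2D layout of the image and can tell that a skyscraper’s top should not appear at the bottom of a Manhattan skyline photo.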
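A common way to read the similarity matrix is per patch: take one row, drop the CLS position if there is one, and reshape it onto the 14x14 patch grid, so you can see whether nearby patches have similar encodings. Since no trained checkpoint is available here, the sketch below assumes fixed 2D sine-cosine positional embeddings as a stand-in for learned ones; sincos_1d, sincos_2d, and similarity_grid are hypothetical helper names.

```python
import numpy as np

def sincos_1d(positions, dim):
    # Sinusoidal encoding of integer positions: dim // 2 sine and dim // 2 cosine channels
    omega = 1.0 / (10000 ** (np.arange(dim // 2) / (dim // 2)))
    angles = np.outer(positions, omega)                    # (num_positions, dim // 2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def sincos_2d(grid, dim):
    # Half of the channels encode the patch row, half the patch column
    rows, cols = np.meshgrid(np.arange(grid), np.arange(grid), indexing="ij")
    return np.concatenate([sincos_1d(rows.ravel(), dim // 2),
                           sincos_1d(cols.ravel(), dim // 2)], axis=1)

def similarity_grid(pos_weights, row, col, grid):
    # Cosine similarity of patch (row, col) to every patch, laid out as a grid x grid map.
    # pos_weights must hold patch positions only (drop a CLS row first if present).
    unit = pos_weights / np.linalg.norm(pos_weights, axis=1, keepdims=True)
    sims = unit @ unit[row * grid + col]
    return sims.reshape(grid, grid)

grid, dim = 14, 64                        # 14 x 14 patches, assumed 64-dim embeddings
pos_weights = sincos_2d(grid, dim)        # stand-in for learned weights: (196, 64)
sim_map = similarity_grid(pos_weights, row=7, col=7, grid=grid)

peak = tuple(int(i) for i in np.unravel_index(sim_map.argmax(), sim_map.shape))
print(peak)                               # (7, 7): a patch is most similar to itself
print(sim_map[7, 8] > sim_map[0, 0])      # True: a neighbor beats a distant corner
```

For a learned embedding, the same similarity_grid call on the patch rows of pos_weights should show a similar bright blob around the chosen patch if the model has picked up the 2D layout.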

Probe Intermediate MLP Representations in Keras

The Multi-Layer Perceptron (MLP) blocks within the Transformer encoder refine the features extracted by the attention mechanism.

def probe_transformer_blocks(model, image, block_index=5):
    # Access a specific Transformer encoder block by name
    # (assumes blocks are named transformer_block_0, transformer_block_1, ...)
    layer_name = f"transformer_block_{block_index}"
    intermediate_layer = model.get_layer(layer_name)

    # Functional sub-model that stops at this block's output; in a standard ViT
    # block this is the representation after both the attention and MLP sublayers
    probe_model = keras.Model(inputs=model.input, outputs=intermediate_layer.output)

    # High-dimensional representation, typically (batch, num_tokens, embed_dim)
    features = probe_model.predict(image, verbose=0)
    return features

# Example: get features from the block named "transformer_block_5"
# block_features = probe_transformer_blocks(vit_model, sample_img)

I use probing to extract these intermediate representations, which is helpful when I want to use a pre-trained ViT as a feature extractor for a custom logistics app.
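A common next step after probing is a linear probe: keep the extracted features frozen, fit only a linear classifier on top, and use its test accuracy as a rough measure of how linearly separable the classes are at that block. The sketch below is a minimal closed-form ridge-regression probe in NumPy (a simple stand-in for the logistic-regression probes often used); it assumes features shaped (n_examples, num_tokens, embed_dim), mean-pools them over tokens, and uses synthetic, clearly separable features in place of real block outputs. linear_probe_accuracy is a hypothetical helper.

```python
import numpy as np

def linear_probe_accuracy(train_feats, train_labels, test_feats, test_labels,
                          num_classes, l2=1e-2):
    """Fit a ridge-regression probe on frozen features and return test accuracy."""
    def prepare(feats):
        # Mean-pool over tokens if needed, then append a bias column
        pooled = feats.mean(axis=1) if feats.ndim == 3 else feats
        return np.hstack([pooled, np.ones((len(pooled), 1))])

    x_train, x_test = prepare(train_feats), prepare(test_feats)
    targets = np.eye(num_classes)[train_labels]                 # one-hot labels
    # Closed-form ridge solution: (X^T X + l2 * I) W = X^T Y
    gram = x_train.T @ x_train + l2 * np.eye(x_train.shape[1])
    weights = np.linalg.solve(gram, x_train.T @ targets)
    predictions = (x_test @ weights).argmax(axis=1)
    return float(np.mean(predictions == test_labels))

# Synthetic stand-in features: 3 classes, 5 tokens, 32 dims, class-dependent means
rng = np.random.default_rng(0)
class_means = 3.0 * rng.normal(size=(3, 32))

def make_split(n):
    labels = rng.integers(0, 3, size=n)
    feats = class_means[labels][:, None, :] + rng.normal(size=(n, 5, 32))
    return feats, labels

train_x, train_y = make_split(300)
test_x, test_y = make_split(100)
accuracy = linear_probe_accuracy(train_x, train_y, test_x, test_y, num_classes=3)
print(f"Linear probe accuracy: {accuracy:.2f}")
```

Running the same probe on several blocks and comparing the accuracies gives a quick picture of where in the network task-relevant information becomes linearly accessible.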

Measure Representation Similarity with CKA in Keras

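Centered Kernel Alignment (CKA) is a similarity metric I use to compare the representations of two layers or two different Keras models. Linear CKA is invariant to orthogonal transformations and uniform scaling of the features, and it can compare layers of different widths.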

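def calculate_cka_similarity(feat1, feat2):
    # Linear CKA between two feature sets describing the SAME examples in the same order.
    # Flatten token/spatial axes: (n_examples, ...) -> (n_examples, n_features)
    feat1 = feat1.reshape(feat1.shape[0], -1).astype("float64")
    feat2 = feat2.reshape(feat2.shape[0], -1).astype("float64")

    # Center each feature across examples
    feat1 = feat1 - np.mean(feat1, axis=0)
    feat2 = feat2 - np.mean(feat2, axis=0)

    # Use (n x n) Gram matrices so memory scales with the number of examples, not
    # the feature width; this equals ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    gram1 = np.dot(feat1, feat1.T)
    gram2 = np.dot(feat2, feat2.T)
    cross_term = np.sum(gram1 * gram2)

    return cross_term / (np.linalg.norm(gram1) * np.linalg.norm(gram2))

# Usage: compare features from two different layers or models
# similarity_score = calculate_cka_similarity(block_features_1, block_features_2)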

I find this particularly useful when comparing a lightweight MobileViT model against a large ViT-Base to see if the smaller model is capturing the same hierarchy of features.
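Two properties make linear CKA a good fit for this kind of comparison: identical representations score 1, and rotating the feature basis or rescaling it uniformly does not change the score, so two models that encode the same information in a differently oriented basis still score highly. The self-contained check below re-implements the standard linear CKA formula on random stand-in features to show both properties, plus a much lower score for unrelated features; linear_cka is a hypothetical helper name.

```python
import numpy as np

def linear_cka(x, y):
    # x: (n, d1), y: (n, d2), with rows describing the same n examples in the same order
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    cross = np.linalg.norm(y.T @ x) ** 2           # ||Y^T X||_F^2
    return cross / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y))

rng = np.random.default_rng(0)
feats = rng.normal(size=(500, 64))                  # stand-in features: 500 examples
q, _ = np.linalg.qr(rng.normal(size=(64, 64)))      # a random orthogonal matrix
rotated_scaled = 3.0 * feats @ q                    # rotate the basis, rescale uniformly
unrelated = rng.normal(size=(500, 64))              # independent features

cka_self = linear_cka(feats, feats)
cka_rotated = linear_cka(feats, rotated_scaled)
cka_unrelated = linear_cka(feats, unrelated)
print(round(float(cka_self), 4))        # 1.0
print(round(float(cka_rotated), 4))     # 1.0
print(round(float(cka_unrelated), 4))   # well below 1 for independent features
```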

Investigate Class Token (CLS) Evolution in Keras

The Class Token (CLS) is a special vector that aggregates information from all other patches to perform the final classification.

def track_cls_token(model, image):
    # Output tensors of every Transformer block, in layer order
    # (assumes block layers have "transformer_block" in their names)
    block_outputs = [layer.output for layer in model.layers
                     if "transformer_block" in layer.name]

    # One multi-output sub-model runs the image once, instead of once per block
    intermediate_model = keras.Model(inputs=model.input, outputs=block_outputs)
    outputs = intermediate_model.predict(image, verbose=0)
    if isinstance(outputs, np.ndarray):
        outputs = [outputs]  # a single block comes back as a bare array

    # The CLS token is usually the first element in the sequence (index 0)
    return [out[:, 0, :] for out in outputs]

# cls_history = track_cls_token(vit_model, sample_img)

I track how this token changes as it passes through the layers to see at which point the model “decides” it is looking at a Ford F-150 versus a Chevy Silverado.
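One simple way to locate where the CLS token settles is to compare its state after each block with its state after the next block: a cosine similarity close to 1 means that block barely changed it, while a dip marks blocks where the representation moves most. The sketch below assumes a history shaped like the list track_cls_token returns (one (batch, embed_dim) array per block) and fills it with a synthetic random walk whose steps shrink with depth, purely for illustration; cls_change_per_block is a hypothetical helper.

```python
import numpy as np

def cls_change_per_block(cls_history):
    """Cosine similarity between the CLS state after block i and after block i + 1.

    cls_history: list of arrays shaped (batch, embed_dim), one per block.
    Returns an array shaped (num_blocks - 1, batch).
    """
    sims = []
    for prev, curr in zip(cls_history[:-1], cls_history[1:]):
        dot = np.sum(prev * curr, axis=-1)
        norms = np.linalg.norm(prev, axis=-1) * np.linalg.norm(curr, axis=-1)
        sims.append(dot / np.maximum(norms, 1e-12))
    return np.stack(sims)

# Synthetic stand-in for 12 blocks: a random walk whose steps shrink with depth
rng = np.random.default_rng(0)
state = rng.normal(size=(2, 768))
history = [state]
for step in range(11):
    state = state + rng.normal(size=state.shape) / (step + 1)
    history.append(state)

sims_per_block = cls_change_per_block(history)
print(sims_per_block.shape)                   # (11, 2)
print(sims_per_block.mean(axis=1).round(3))   # closer to 1 in later blocks
```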

Fine-Tuning Vision Transformer Heads in Keras

Once I understand the representations, I often need to replace the classification head for a specific task, such as identifying local crop diseases in Iowa.

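def fine_tune_vit_head(base_model, num_classes):
    # Freeze the transformer backbone so only the new head is trained
    base_model.trainable = False

    # Add a custom classification head
    inputs = layers.Input(shape=(224, 224, 3))
    # training=False keeps layers such as dropout in inference mode
    x = base_model(inputs, training=False)

    # If the backbone returns the full token sequence (batch, num_tokens, embed_dim),
    # keep only the CLS token at index 0 (layers.GlobalAveragePooling1D() is a
    # common alternative); a backbone that already pools is used as-is
    if len(x.shape) == 3:
        x = x[:, 0, :]
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    custom_model = keras.Model(inputs, outputs)
    # categorical_crossentropy expects one-hot labels; use
    # sparse_categorical_crossentropy for integer labels
    custom_model.compile(optimizer="adam", loss="categorical_crossentropy",
                         metrics=["accuracy"])

    return custom_model

# final_model = fine_tune_vit_head(vit_base, 10)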

This method shows how I freeze the transformer backbone and train only the representation-to-label mapping.
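To confirm that the backbone is really frozen, you can count the trainable weights after compiling and check that the backbone's weights do not move during training: only the new Dense head's kernel and bias should be trainable. The sketch below uses a tiny hypothetical stand-in backbone (a strided Conv2D patchifier plus a per-token Dense layer on 64x64 inputs) instead of a real ViT so it runs in seconds, and it pools tokens with GlobalAveragePooling1D.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical tiny stand-in backbone: 16 patch tokens of width 32
backbone = keras.Sequential([
    keras.Input(shape=(64, 64, 3)),
    layers.Conv2D(32, kernel_size=16, strides=16),    # patchify + project
    layers.Reshape((16, 32)),                         # (batch, num_tokens, embed_dim)
    layers.Dense(32, activation="relu"),
], name="stand_in_backbone")
backbone.trainable = False                            # freeze the backbone

inputs = keras.Input(shape=(64, 64, 3))
x = backbone(inputs, training=False)
x = layers.GlobalAveragePooling1D()(x)                # pool tokens into one vector
outputs = layers.Dense(3, activation="softmax")(x)    # new classification head
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Only the head's kernel and bias should be trainable
num_trainable = len(model.trainable_weights)
print(num_trainable)  # 2

# One quick training pass on random data: the backbone weights must not change
before = [w.copy() for w in backbone.get_weights()]
images = np.random.rand(8, 64, 64, 3).astype("float32")
labels = np.random.randint(0, 3, size=(8,))
model.fit(images, labels, epochs=1, verbose=0)
unchanged = all(np.array_equal(b, a) for b, a in zip(before, backbone.get_weights()))
print(unchanged)  # True
```

If you later unfreeze the backbone for full fine-tuning, recompile the model afterwards (typically with a much lower learning rate) so the change in trainable weights takes effect.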

In this tutorial, I showed you several methods I use to investigate Vision Transformer representations in Keras.

I started by looking at patch embeddings and then moved to more complex analyses like CKA similarity and CLS token tracking.

By using these techniques, you can gain a much better understanding of how your models process visual information.
