As a Keras developer who has spent the last four years building computer vision models, I have always been fascinated by how Vision Transformers (ViT) “see” the world compared to traditional CNNs.
When I first transitioned from ResNet to ViT, I struggled to understand how these global self-attention mechanisms actually processed my image data.
In this tutorial, I will show you exactly how I investigate Vision Transformer representations in Keras to ensure my models are learning the right features.
Extract Patch Embeddings in Keras
The first step in any Vision Transformer pipeline is breaking an image into patches and projecting them into an embedding space.
```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def get_patch_embeddings(image, patch_size=16, embed_dim=768):
    # Number of non-overlapping patches for a square input image
    image_size = image.shape[1]
    num_patches = (image_size // patch_size) ** 2

    # A Conv2D whose kernel size equals its stride projects each patch
    # independently. Note: this layer is freshly initialized (untrained),
    # so it demonstrates shapes, not learned embeddings.
    projection = layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
                               strides=patch_size, padding="valid")
    patches = projection(image)  # (batch, H/patch_size, W/patch_size, embed_dim)

    # Flatten the patch grid into a sequence of tokens, keeping the batch size
    return tf.reshape(patches, (-1, num_patches, embed_dim))


# Example: a random array standing in for a 224x224 RGB image
sample_img = np.random.normal(size=(1, 224, 224, 3)).astype("float32")
embeddings = get_patch_embeddings(sample_img)
print(f"Embedding shape: {embeddings.shape}")  # Embedding shape: (1, 196, 768)
```

You can refer to the screenshot below to see the output.

I use this method to verify that an image, like a satellite view of a Chicago suburb, is split into the expected grid of patches before it enters the transformer layers: a 224 × 224 input with 16 × 16 patches should yield 14 × 14 = 196 tokens.
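To see what the strided Conv2D is doing, here is a minimal NumPy sketch (an illustration, not Keras internals): splitting the image into non-overlapping patches with reshape/transpose and multiplying each flattened patch by one shared weight matrix gives the same result as a convolution whose kernel size and stride both equal the patch size. The helpers `patchify` and `strided_conv` are hypothetical names, and the small sizes only keep the check fast.

```python
import numpy as np


def patchify(images, patch_size):
    """Split (batch, H, W, C) images into (batch, num_patches, p*p*C) vectors."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images[:, :gh * patch_size, :gw * patch_size, :]
    x = x.reshape(b, gh, patch_size, gw, patch_size, c)
    # Bring the two grid axes together, then flatten each patch in (row, col, channel) order
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)


def strided_conv(images, kernel, bias, patch_size):
    """Reference 'valid' convolution with stride == kernel size, written as explicit loops."""
    b, h, w, _ = images.shape
    gh, gw = h // patch_size, w // patch_size
    out = np.zeros((b, gh, gw, kernel.shape[-1]))
    for i in range(gh):
        for j in range(gw):
            window = images[:, i * patch_size:(i + 1) * patch_size,
                            j * patch_size:(j + 1) * patch_size, :]
            # Sum over patch rows, patch columns and input channels
            out[:, i, j, :] = np.einsum("bhwc,hwco->bo", window, kernel) + bias
    return out.reshape(b, gh * gw, -1)


rng = np.random.default_rng(0)
patch_size, channels, embed_dim = 4, 3, 8
images = rng.normal(size=(2, 16, 16, channels))
kernel = rng.normal(size=(patch_size, patch_size, channels, embed_dim))  # Conv2D kernel layout
bias = rng.normal(size=embed_dim)

# Linear patch embedding: flatten patches, then apply one shared projection
tokens = patchify(images, patch_size) @ kernel.reshape(-1, embed_dim) + bias
print(tokens.shape)  # (2, 16, 8)
assert np.allclose(tokens, strided_conv(images, kernel, bias, patch_size))
```

With `patch_size=16` on a 224 × 224 × 3 input, `patchify` returns 196 vectors of length 16 · 16 · 3 = 768, matching the 196 tokens that `get_patch_embeddings` produces.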
Visualize Self-Attention Maps in Keras
Visualizing attention maps is one of the most direct ways to see which parts of an image the ViT focuses on when making a classification decision.
```python
def visualize_attention_maps(model, image):
    # Assumes the model contains a MultiHeadAttention layer named
    # "multi_head_attention_layer" that was called with
    # return_attention_scores=True, so it outputs (attention_output, scores).
    attention_layer = model.get_layer("multi_head_attention_layer")

    # Sub-model that exposes the attention layer's outputs
    intermediate_model = keras.Model(inputs=model.input, outputs=attention_layer.output)

    # Scores have shape (batch, num_heads, query_len, key_len)
    _, attention_scores = intermediate_model.predict(image)

    # Average across heads for a global view: (batch, query_len, key_len)
    return np.mean(attention_scores, axis=1)


# `model` is assumed to take sequences of 8 tokens with 16 features each
sample_img = np.random.rand(1, 8, 16).astype("float32")
heatmap = visualize_attention_maps(model, sample_img)
print(heatmap.shape)
```

You can refer to the screenshot below to see the output.

I frequently use this technique to debug why a model might be misidentifying a California redwood tree or a New York City taxi by looking at the heatmaps generated by the heads.
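The head-averaged scores form a (query, key) matrix per image. One common way to turn them into an image-sized heatmap, sketched below in NumPy under the assumption that the CLS token sits at index 0 and the remaining tokens form a square patch grid, is to take the CLS row, drop its self-attention entry, reshape the rest into the grid, and upsample each cell to the patch size. `cls_attention_heatmap` is a hypothetical helper, and the random logits are stand-ins for real attention scores.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cls_attention_heatmap(attention_scores, patch_size):
    """attention_scores: (batch, num_heads, tokens, tokens), with token 0 = CLS."""
    avg = attention_scores.mean(axis=1)   # average heads -> (batch, tokens, tokens)
    cls_to_patches = avg[:, 0, 1:]        # how much CLS attends to each patch
    grid = int(np.sqrt(cls_to_patches.shape[-1]))
    maps = cls_to_patches.reshape(-1, grid, grid)
    # Nearest-neighbour upsampling: repeat each cell patch_size times along both axes
    maps = np.repeat(np.repeat(maps, patch_size, axis=1), patch_size, axis=2)
    # Rescale each map to [0, 1] for display
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    return (maps - lo) / (hi - lo + 1e-12)


rng = np.random.default_rng(0)
# Stand-in scores: 12 heads over 1 CLS token + 196 patch tokens
logits = rng.normal(size=(1, 12, 197, 197))
scores = softmax(logits, axis=-1)  # each row sums to 1, like attention weights
heatmap = cls_attention_heatmap(scores, patch_size=16)
print(heatmap.shape)  # (1, 224, 224)
```

The resulting map has the input image's spatial size, so it can be overlaid on the original picture to see which regions the CLS token draws on.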
Analyze Positional Encodings in Keras
Since Transformers don’t inherently understand the order of patches, positional encodings are added to give the model a sense of “where” things are.
```python
def analyze_positional_embeddings(model):
    # Assumes a layer named "pos_embedding" whose first weight is the
    # (num_positions, embed_dim) positional embedding matrix
    pos_emb_layer = model.get_layer("pos_embedding")
    pos_weights = pos_emb_layer.get_weights()[0]

    # Cosine similarity between every pair of positions
    dot_product = np.dot(pos_weights, pos_weights.T)
    norm = np.linalg.norm(pos_weights, axis=1)
    similarity = dot_product / np.outer(norm, norm)
    return similarity  # (num_positions, num_positions)


similarity_matrix = analyze_positional_embeddings(model)
print(similarity_matrix.shape)
```

You can refer to the screenshot below to see the output.

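I analyze these encodings to check that the model has learned a sensible spatial layout: neighboring patches should have more similar positional embeddings than distant ones, which is what lets the model know that a skyscraper's top shouldn't be at the bottom of a Manhattan skyline photo.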
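Learned ViT position embeddings have no closed form, so a useful reference point is the fixed sine-cosine encoding from the original Transformer. This NumPy sketch builds that encoding (a swapped-in baseline, not the learned `pos_embedding` weights) and applies the same cosine-similarity computation. Because each sine/cosine pair contributes cos(ω·Δ), the similarity depends only on the offset Δ between two positions, so the matrix is constant along each diagonal. `sinusoidal_encoding` and `cosine_similarity_matrix` are illustrative names.

```python
import numpy as np


def sinusoidal_encoding(num_positions, dim):
    """Fixed encoding: PE[p, 2i] = sin(p * w_i), PE[p, 2i+1] = cos(p * w_i)."""
    positions = np.arange(num_positions)[:, None]        # (P, 1)
    freqs = 1.0 / 10000 ** (np.arange(0, dim, 2) / dim)  # w_i, shape (dim/2,)
    angles = positions * freqs                           # (P, dim/2)
    pe = np.zeros((num_positions, dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe


def cosine_similarity_matrix(weights):
    # Same computation as analyze_positional_embeddings, on a raw matrix
    norm = np.linalg.norm(weights, axis=1)
    return (weights @ weights.T) / np.outer(norm, norm)


pe = sinusoidal_encoding(num_positions=197, dim=64)
sim = cosine_similarity_matrix(pe)
print(sim.shape)  # (197, 197)
```

Plotting `sim` as an image gives a reference picture to compare against the matrix returned by `analyze_positional_embeddings` for a trained model.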
Probe Intermediate MLP Representations in Keras
The Multi-Layer Perceptron (MLP) blocks within the Transformer encoder refine the features extracted by the attention mechanism.
```python
def probe_transformer_blocks(model, image, block_index=5):
    # Access a specific Transformer encoder block by name
    # (assumes the blocks are named "transformer_block_<index>")
    layer_name = f"transformer_block_{block_index}"
    intermediate_layer = model.get_layer(layer_name)

    # Functional sub-model that stops at the chosen block
    probe_model = keras.Model(inputs=model.input, outputs=intermediate_layer.output)

    # High-dimensional representation, typically (batch, num_tokens, embed_dim)
    features = probe_model.predict(image)
    return features


# Example: get features from the block named "transformer_block_5" of a ViT
# block_features = probe_transformer_blocks(vit_model, sample_img)
```

I use probing to extract these intermediate representations, which is helpful when I want to use a pre-trained ViT as a feature extractor for a custom logistics app.
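Block outputs usually come back as (batch, num_tokens, embed_dim), so they need to be reduced to one vector per image before they can feed a downstream model. This NumPy sketch shows two common choices, taking the CLS token or mean-pooling all tokens, followed by a closed-form ridge-regression linear probe. `pool_tokens` and `fit_linear_probe` are hypothetical helpers, and the random features and targets are synthetic stand-ins for real block outputs and labels.

```python
import numpy as np


def pool_tokens(features, mode="cls"):
    """Reduce (batch, tokens, dim) block outputs to (batch, dim)."""
    if mode == "cls":
        return features[:, 0, :]      # assumes the CLS token is at index 0
    if mode == "mean":
        return features.mean(axis=1)  # average over all tokens
    raise ValueError(f"unknown mode: {mode}")


def fit_linear_probe(x, y, l2=1e-3):
    """Closed-form ridge regression y ~ x @ w + b; returns (w, b)."""
    x_aug = np.hstack([x, np.ones((x.shape[0], 1))])  # append a bias column
    reg = l2 * np.eye(x_aug.shape[1])
    reg[-1, -1] = 0.0                                 # do not penalize the bias
    coef = np.linalg.solve(x_aug.T @ x_aug + reg, x_aug.T @ y)
    return coef[:-1], coef[-1]


rng = np.random.default_rng(0)
block_features = rng.normal(size=(256, 197, 32))  # synthetic (batch, tokens, dim)
pooled = pool_tokens(block_features, mode="mean")
true_w = rng.normal(size=(32, 3))
targets = pooled @ true_w + 0.5                   # synthetic targets, exactly linear in the features
w, b = fit_linear_probe(pooled, targets, l2=0.0)
print(w.shape, b.shape)  # (32, 3) (3,)
```

Because the synthetic targets are an exact linear function of the pooled features, the probe recovers `true_w` and the offset 0.5; on real features, how well a linear probe fits is a rough measure of how linearly accessible the information is at that block.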
Measure Representation Similarity with CKA in Keras
Centered Kernel Alignment (CKA) is a powerful metric I use to compare how similar the representations of two different Keras models are.
```python
def calculate_cka_similarity(feat1, feat2):
    # Linear CKA; expects 2D arrays of shape (num_examples, num_features)
    # computed on the same examples. Flatten or pool token dimensions first.
    feat1 = feat1 - np.mean(feat1, axis=0)
    feat2 = feat2 - np.mean(feat2, axis=0)

    # ||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross_norm_sq = np.linalg.norm(np.dot(feat1.T, feat2)) ** 2
    norm_feat1 = np.linalg.norm(np.dot(feat1.T, feat1))
    norm_feat2 = np.linalg.norm(np.dot(feat2.T, feat2))
    return cross_norm_sq / (norm_feat1 * norm_feat2)


# Usage: compare features from two different layers or models
# similarity_score = calculate_cka_similarity(block_features_1, block_features_2)
```

I find this particularly useful when comparing a lightweight MobileViT model against a large ViT-Base to see whether the smaller model captures the same hierarchy of features.
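Two properties make linear CKA convenient for comparing layers of different widths: it equals 1 when one representation is a rotated and uniformly scaled copy of the other, and it always lies between 0 and 1. The NumPy sketch below re-implements the same linear CKA formula, flattens token dimensions so 3D block outputs can be compared directly, and checks those properties on synthetic data. `linear_cka` is an illustrative name, not a Keras API.

```python
import numpy as np


def linear_cka(x, y):
    """Linear CKA between (n_examples, ...) arrays computed on the same examples."""
    # Flatten everything except the example axis so (batch, tokens, dim) works too
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    # Center each feature across examples
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    cross = np.linalg.norm(x.T @ y) ** 2
    return cross / (np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y))


rng = np.random.default_rng(0)
feats_a = rng.normal(size=(100, 10, 16))          # synthetic (batch, tokens, dim)
flat_a = feats_a.reshape(100, -1)
q, _ = np.linalg.qr(rng.normal(size=(160, 160)))  # random orthogonal matrix
feats_b = 3.0 * flat_a @ q                        # rotated and rescaled copy of feats_a
feats_c = rng.normal(size=(100, 64))              # unrelated features of a different width

print(round(linear_cka(feats_a, feats_b), 6))  # 1.0
print(linear_cka(feats_a, feats_c))
```

The invariance to rotation and scale is what allows a MobileViT layer and a ViT-Base layer with different embedding sizes to be compared at all; only the number of examples has to match.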
Investigate Class Token (CLS) Evolution in Keras
The Class Token (CLS) is a special vector that aggregates information from all other patches to perform the final classification.
```python
def track_cls_token(model, image):
    cls_outputs = []
    # Walk the layers and record the CLS token state after each Transformer block
    for layer in model.layers:
        if "transformer_block" in layer.name:
            intermediate_model = keras.Model(inputs=model.input, outputs=layer.output)
            out = intermediate_model.predict(image)
            # The CLS token is usually the first element in the sequence
            cls_outputs.append(out[:, 0, :])
    return cls_outputs


# cls_history = track_cls_token(vit_model, sample_img)
```

I track how this token changes as it passes through the layers to see at which point the model “decides” it is looking at a Ford F-150 versus a Chevy Silverado.
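One way to pinpoint where the model “decides” is to push each intermediate CLS state through the final classification head and find the first block after which the predicted class never changes. This is a simplified variant of the “logit lens” idea, sketched in NumPy under two assumptions: the head is a single linear layer with weights `head_w` and bias `head_b`, and applying it to intermediate states is meaningful (many ViTs apply a final LayerNorm before the head, which this sketch skips). `decision_block` is a hypothetical helper, and the toy CLS states are hand-made.

```python
import numpy as np


def decision_block(cls_outputs, head_w, head_b):
    """Per example, the first block index from which the prediction stays fixed.

    cls_outputs: list of (batch, dim) CLS states, one per Transformer block.
    head_w: (dim, num_classes), head_b: (num_classes,)
    """
    # Predicted class after every block: (num_blocks, batch)
    preds = np.stack([np.argmax(c @ head_w + head_b, axis=-1) for c in cls_outputs])
    agrees = (preds == preds[-1]).astype(int)
    # 1 where this block and every later block agree with the final prediction
    stable = np.flip(np.cumprod(np.flip(agrees, axis=0), axis=0), axis=0)
    return len(cls_outputs) - stable.sum(axis=0)


# Toy example: 4 blocks, 2 images, 3 classes and an identity head
head_w, head_b = np.eye(3), np.zeros(3)
cls_outputs = [
    np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),  # block 0 predicts classes 0, 2
    np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),  # block 1 predicts 1, 2
    np.array([[0.0, 2.0, 0.0], [1.0, 0.0, 0.0]]),  # block 2 predicts 1, 0
    np.array([[0.0, 3.0, 0.0], [0.0, 0.0, 1.0]]),  # block 3 predicts 1, 2
]
print(decision_block(cls_outputs, head_w, head_b))  # [1 3]
```

The second image shows why the check looks at every later block: its prediction matches the final class at blocks 0 and 1, flips at block 2, and only settles at block 3.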
Fine-Tuning Vision Transformer Heads in Keras
Once I understand the representations, I often need to replace the classification head for a specific task, such as identifying local crop diseases in Iowa.
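```python
def fine_tune_vit_head(base_model, num_classes):
    # Freeze the transformer backbone
    base_model.trainable = False

    # Add a custom classification head
    inputs = layers.Input(shape=(224, 224, 3))
    x = base_model(inputs, training=False)

    # Assumes base_model returns a token sequence (batch, num_tokens, dim);
    # pool it to one vector per image. Use x[:, 0] instead to target the CLS
    # token, or drop this line if base_model already returns a pooled vector.
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    custom_model = keras.Model(inputs, outputs)
    # categorical_crossentropy expects one-hot labels;
    # use sparse_categorical_crossentropy for integer labels
    custom_model.compile(optimizer="adam", loss="categorical_crossentropy",
                         metrics=["accuracy"])
    return custom_model


# final_model = fine_tune_vit_head(vit_base, 10)
```

This method shows how I freeze the transformer backbone and train only the representation-to-label mapping.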
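Freezing the backbone means the only thing being learned is a softmax layer on top of fixed features. The NumPy sketch below makes that concrete: it treats precomputed (CLS or pooled) features as constants and trains just a weight matrix and bias with full-batch gradient descent on categorical cross-entropy with one-hot labels, the same loss the Keras head is compiled with. The synthetic features, `train_softmax_head`, and the learning rate are illustrative assumptions, not the author's training setup.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def train_softmax_head(features, labels, num_classes, lr=0.1, steps=200):
    """Gradient descent on a linear softmax head; the features stay frozen."""
    n, dim = features.shape
    y = np.eye(num_classes)[labels]  # one-hot, as categorical_crossentropy expects
    w, b = np.zeros((dim, num_classes)), np.zeros(num_classes)
    losses = []
    for _ in range(steps):
        probs = softmax(features @ w + b)
        losses.append(-np.mean(np.sum(y * np.log(probs + 1e-12), axis=1)))
        grad = (probs - y) / n       # d(mean loss)/d(logits)
        w -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return w, b, losses


rng = np.random.default_rng(0)
num_classes = 4
labels = rng.integers(0, num_classes, size=300)
# Synthetic "frozen" features: class-dependent means plus noise
features = rng.normal(size=(300, 16)) + 2.0 * np.eye(num_classes, 16)[labels]
w, b, losses = train_softmax_head(features, labels, num_classes)
print(round(losses[0], 4))  # 1.3863, i.e. log(4) with zero-initialized weights
```

Only `w` and `b` change during training, which is exactly the representation-to-label mapping; the backbone's features never receive gradients.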
In this tutorial, I showed you several methods I use to investigate Vision Transformer representations in Keras.
I started by looking at patch embeddings and then moved to more complex analyses like CKA similarity and CLS token tracking.
By using these techniques, you can gain a much better understanding of how your models process visual information.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.