I’ve found that scaling Vision Transformers (ViT) often leads to significant training instability.
Standard ViT architectures tend to saturate or diverge when you add too many layers, which can be quite frustrating during model development.
Recently, I started using Class Attention Image Transformers (CaiT), which introduces LayerScale to handle these deep architectural challenges effectively.
In this tutorial, I will show you how to implement CaiT with LayerScale using Keras to build robust image classification models.
CaiT Architecture in Keras
CaiT separates the processing of image patches from the class embedding, which allows the model to focus on features before making a classification.
In my experience, this separation prevents the “class token” from interfering with the early-stage self-attention layers of the transformer.
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
# Initializing a basic CaiT configuration for a US-based retail dataset
IMAGE_SIZE = 224
PATCH_SIZE = 16
NUM_LAYERS = 12
NUM_HEADS = 8
PROJECTION_DIM = 192
Implement LayerScale in Keras
LayerScale is a simple yet powerful technique that multiplies the output of each residual branch by a learnable per-channel scale (a diagonal matrix) initialized to a small value such as 1e-5.
I have found that this prevents the signal from exploding in very deep networks, making the training process much smoother from the start.
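To make the effect concrete, here is a minimal NumPy sketch of a LayerScale residual update. The function name `layerscale_residual`, the random inputs, and the exaggerated branch output are assumptions for illustration only; the point is that with a scale of 1e-5 (the value used later in this tutorial) each block starts out very close to the identity.

```python
import numpy as np

def layerscale_residual(x, branch_out, gamma):
    """Residual update x + gamma * branch_out, with one scale per channel."""
    return x + branch_out * gamma

rng = np.random.default_rng(0)
dim = 192                                       # same width as PROJECTION_DIM
x = rng.normal(size=(4, dim))                   # 4 hypothetical tokens
branch_out = 10.0 * rng.normal(size=(4, dim))   # stand-in for a large attention/MLP output

gamma = np.full(dim, 1e-5)                      # small initial scale per channel
y = layerscale_residual(x, branch_out, gamma)

# Largest change the branch makes to any activation at initialization
max_shift = np.abs(y - x).max()
print(f"max shift at init: {max_shift:.2e}")
```

Because `gamma` is learnable, training can grow it per channel wherever the branch output turns out to be useful.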
class LayerScale(layers.Layer):
    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Register gamma through add_weight so Keras tracks and trains it
        self.gamma = self.add_weight(
            name="gamma", shape=(projection_dim,),
            initializer=tf.keras.initializers.Constant(init_values), trainable=True)

    def call(self, x):
        # Multiplying the input by the learnable diagonal matrix
        return x * self.gamma
Create the Multi-Head Self-Attention Block
The self-attention block is the heart of the transformer, where the model learns spatial relationships between different parts of the image.
I prefer using the built-in Keras MultiHeadAttention layer because it is highly optimized for GPU performance during training.
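For intuition about what the layer computes, here is a rough NumPy sketch of multi-head self-attention. It is not how Keras implements it internally; the weight matrices are random stand-ins and the function names are hypothetical. It does show how the channel dimension is split across heads, which is why the per-head size is `dim // num_heads` (in Keras, `key_dim` is the size of each head).

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, wq, wk, wv, wo, num_heads):
    """x: (tokens, dim); wq/wk/wv/wo: (dim, dim). Returns (output, attention weights)."""
    n, dim = x.shape
    head_dim = dim // num_heads

    def split(t):
        # (tokens, dim) -> (heads, tokens, head_dim)
        return t.reshape(n, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    # Scaled dot-product scores per head: (heads, tokens, tokens)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    # Weighted sum of values, then merge the heads back into (tokens, dim)
    out = (weights @ v).transpose(1, 0, 2).reshape(n, dim)
    return out @ wo, weights

rng = np.random.default_rng(0)
n_tokens, dim, heads = 196, 192, 8            # 14 x 14 patches, PROJECTION_DIM, NUM_HEADS
x = rng.normal(size=(n_tokens, dim))
wq, wk, wv, wo = (rng.normal(scale=dim ** -0.5, size=(dim, dim)) for _ in range(4))
out, weights = multi_head_self_attention(x, wq, wk, wv, wo, heads)
print(out.shape, weights.shape)
```

Note that the attention weights grow with the square of the number of tokens, which is the cost the class-attention stage later avoids.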
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def attention_block(x, projection_dim, num_heads, dropout, init_values):
    # Implementing the residual connection with LayerScale
    res = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    # key_dim is the size of each head, so split projection_dim across the heads
    x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim // num_heads, dropout=dropout)(x, x)
    x = LayerScale(init_values, projection_dim)(x)
    x = layers.Add()([res, x])
    # Feed-forward network with LayerScale
    res = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = mlp(x, hidden_units=[projection_dim * 4, projection_dim], dropout_rate=dropout)
    x = LayerScale(init_values, projection_dim)(x)
    return layers.Add()([res, x])
Build the Class Attention Layer
Class attention layers are unique to CaiT as they only update the class token while keeping the image patch embeddings frozen.
In my projects involving complex datasets like US medical imagery, this method significantly reduces the computational overhead of the final layers.
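The saving comes from using a single query. Below is a simplified single-head NumPy sketch of class attention, with random stand-in weights and a hypothetical `class_attention` helper; layer normalization and LayerScale are left out to keep it short. Only the class token is used as the query, while keys and values come from the class token and the patches together.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def class_attention(cls_token, patches, wq, wk, wv):
    """cls_token: (1, dim); patches: (n, dim). Only the class token is updated."""
    joined = np.concatenate([cls_token, patches], axis=0)   # (1 + n, dim)
    q = cls_token @ wq                                       # a single query row
    k, v = joined @ wk, joined @ wv
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))        # (1, 1 + n)
    return cls_token + weights @ v, weights                  # residual update of the class token

rng = np.random.default_rng(0)
n, dim = 196, 192
patches = rng.normal(size=(n, dim))
cls_token = rng.normal(scale=0.02, size=(1, dim))
wq, wk, wv = (rng.normal(scale=dim ** -0.5, size=(dim, dim)) for _ in range(3))

new_cls, weights = class_attention(cls_token, patches, wq, wk, wv)
print(new_cls.shape, weights.shape)
print("score entries:", weights.size, "vs full self-attention:", (n + 1) ** 2)
```

The patch embeddings pass through unchanged, so the score matrix has one row instead of one row per token.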
class ClassAttention(layers.Layer):
    def __init__(self, projection_dim, num_heads, dropout, init_values, **kwargs):
        super().__init__(**kwargs)
        # key_dim is the size of each head, so split projection_dim across the heads
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim // num_heads, dropout=dropout)
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.ls = LayerScale(init_values, projection_dim)

    def call(self, x, cls_token):
        # Concatenating class token with patch embeddings for cross-attention
        joined = tf.concat([cls_token, x], axis=1)
        joined_norm = self.norm(joined)
        # Querying the class token against all patches
        cls_query = self.norm(cls_token)
        attn_out = self.attn(query=cls_query, value=joined_norm, key=joined_norm)
        return cls_token + self.ls(attn_out)
Patch Encoding for Image Data
Before feeding images into the transformer, we must break them down into smaller squares called patches.
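I use a convolutional layer whose kernel size and stride both equal the patch size. This extracts non-overlapping patches and applies the linear projection in a single step, which is more convenient than reshaping the image and adding a separate Dense layer.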
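As a quick check on the shapes involved, here is a NumPy sketch of the reshape-and-project view of patching. The helper `extract_patches` and the random projection matrix are assumptions for illustration; a 224 x 224 RGB image with 16 x 16 patches gives 14 x 14 = 196 patches of 16 * 16 * 3 = 768 values each.

```python
import numpy as np

def extract_patches(image, patch_size):
    """Split an (H, W, C) image into flattened non-overlapping patches."""
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, c) -> (gh, gw, p, p, c) -> (gh * gw, p * p * c)
    grid = image[:gh * patch_size, :gw * patch_size].reshape(gh, patch_size, gw, patch_size, c)
    return grid.transpose(0, 2, 1, 3, 4).reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))
patches = extract_patches(image, 16)            # one row per patch
w = rng.normal(scale=0.02, size=(768, 192))     # hypothetical projection weights
tokens = patches @ w                            # one embedding per patch
print(patches.shape, tokens.shape)
```

Patches are ordered row by row across the grid, so `patches[1]` is the patch immediately to the right of the top-left one.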
def create_patch_embeddings(inputs, patch_size, projection_dim):
    # Using a Conv2D layer to create patches and project them simultaneously
    patches = layers.Conv2D(filters=projection_dim, kernel_size=patch_size,
                            strides=patch_size, padding="valid")(inputs)
    # Flatten the patch grid into a sequence: (batch, num_patches, projection_dim)
    num_patches = (IMAGE_SIZE // patch_size) ** 2
    return layers.Reshape((num_patches, projection_dim))(patches)
Define the Full CaiT Model Architecture
Now we combine the patch embeddings, standard attention layers, and class attention layers into one cohesive Keras model.
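Note that the original CaiT uses talking-heads attention in its self-attention layers, which mixes information across attention heads. The built-in Keras MultiHeadAttention layer does not do this, so the model here is a simplified variant that uses standard multi-head attention.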
class ClassToken(layers.Layer):
    # Holds the learnable class token so Keras tracks it as a model weight
    def build(self, input_shape):
        self.cls = self.add_weight(name="cls_token", shape=(1, 1, input_shape[-1]),
                                   initializer="zeros", trainable=True)

    def call(self, x):
        # Repeat the token once per example in the batch
        return tf.repeat(self.cls, tf.shape(x)[0], axis=0)

def build_cait_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)
    # Step 1: Patch embedding (no positional embedding is added in this version)
    x = create_patch_embeddings(inputs, PATCH_SIZE, PROJECTION_DIM)
    # Step 2: SA (Self-Attention) stages
    init_values = 1e-5
    for _ in range(NUM_LAYERS):
        x = attention_block(x, PROJECTION_DIM, NUM_HEADS, 0.1, init_values)
    # Step 3: CA (Class-Attention) stages
    cls_token = ClassToken()(x)
    for _ in range(2):  # Usually 2 layers of CA are sufficient
        cls_token = ClassAttention(PROJECTION_DIM, NUM_HEADS, 0.1, init_values)(x, cls_token)
    # Step 4: Classification Head
    cls_token = layers.LayerNormalization(epsilon=1e-6)(cls_token)
    cls_token = layers.Flatten()(cls_token)
    outputs = layers.Dense(num_classes, activation="softmax")(cls_token)
    return models.Model(inputs, outputs)

model = build_cait_model((IMAGE_SIZE, IMAGE_SIZE, 3), 10)
model.summary()
Compile and Train with Custom Schedulers
For transformers, I’ve noticed that a learning rate warm-up is essential to prevent the gradients from collapsing in the first few epochs.
I typically use the AdamW optimizer or Adam with a custom decay to ensure the LayerScale weights converge appropriately.
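A warm-up schedule can be sketched as a plain Python function of the training step. The name `warmup_cosine_lr` and all of the default values (peak rate, warm-up length, total steps, floor) are illustrative assumptions, not tuned settings: the rate rises linearly to its peak and then follows a cosine decay.

```python
import math

def warmup_cosine_lr(step, base_lr=1e-3, warmup_steps=1000, total_steps=10000, min_lr=1e-6):
    """Linear warm-up to base_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Ramp up linearly so the first updates are small
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clipped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 500, 999, 1000, 5500, 10000):
    print(step, f"{warmup_cosine_lr(step):.6f}")
```

In Keras, the same logic could be wrapped in a subclass of `tf.keras.optimizers.schedules.LearningRateSchedule` and passed as `learning_rate` to the optimizer, or applied per epoch through the `tf.keras.callbacks.LearningRateScheduler` callback.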
# Compiling the Keras model for a classification task
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]
)
# Example: Training on a mock dataset representing US flora/fauna images
# history = model.fit(train_dataset, validation_data=val_dataset, epochs=50)
Evaluate Model Performance
After training, I always check the attention maps to ensure the class token is actually looking at the correct features in the image.
CaiT models usually show a much clearer separation of objects from the background compared to standard ViT models I’ve built.
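One way to inspect this is to turn the class token's attention weights into a heat map over the patch grid. The sketch below assumes you already have the scores as an array of shape (heads, 1, 1 + num_patches) with the class token in the first column, which matches the concatenation order used in the ClassAttention layer. The Keras MultiHeadAttention layer can return scores when called with `return_attention_scores=True`, but the model as built in this tutorial does not expose them, so the input here is random stand-in data and `cls_attention_map` is a hypothetical helper.

```python
import numpy as np

def cls_attention_map(scores, grid_size):
    """scores: (heads, 1, 1 + num_patches) class-token attention weights."""
    per_patch = scores.mean(axis=0)[0, 1:]     # average over heads, drop the class-token column
    per_patch = per_patch / per_patch.sum()    # renormalize over the patches only
    return per_patch.reshape(grid_size, grid_size)

rng = np.random.default_rng(0)
heads, grid = 8, 14                            # NUM_HEADS and 224 // 16
raw = rng.random((heads, 1, 1 + grid * grid))
scores = raw / raw.sum(axis=-1, keepdims=True) # stand-in for softmax output

attn_map = cls_attention_map(scores, grid)
row, col = np.unravel_index(attn_map.argmax(), attn_map.shape)
print(attn_map.shape, "most attended patch:", (int(row), int(col)))
```

The resulting 14 x 14 map can be upsampled to the image size and overlaid on the input to see which regions drive the prediction.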
def evaluate_cait_performance(model, test_data):
    # Running evaluation on the test set to verify accuracy
    results = model.evaluate(test_data)
    print(f"Test Loss: {results[0]}, Test Accuracy: {results[1]}")
    return results
You can see the output in the screenshot below.

In this article, I showed you how to implement Class Attention Image Transformers with LayerScale in Keras to build more stable and deeper vision models.
I have found that using LayerScale is a game-changer when you want to train models with more than 12 layers without hitting a performance ceiling.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, etc., for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, etc. Check out my profile.