Implement Class Attention Image Transformers (CaiT) with LayerScale in Keras

I’ve found that scaling Vision Transformers (ViT) often leads to significant training instability.

Standard ViT architectures tend to saturate or diverge when you add too many layers, which can be quite frustrating during model development.

Recently, I started using Class Attention Image Transformers (CaiT), which introduces LayerScale and a dedicated class-attention stage to keep very deep transformers trainable.

In this tutorial, I will show you how to implement CaiT with LayerScale using Keras to build robust image classification models.

CaiT Architecture in Keras

CaiT separates the processing of image patches from the class embedding, which allows the model to focus on features before making a classification.

In my experience, this separation prevents the “class token” from interfering with the early-stage self-attention layers of the transformer.

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Initializing a basic CaiT configuration for a US-based retail dataset
IMAGE_SIZE = 224
PATCH_SIZE = 16
NUM_LAYERS = 12
NUM_HEADS = 8
PROJECTION_DIM = 192

Implement LayerScale in Keras

LayerScale is a simple yet powerful technique that multiplies the output of each residual branch by a learnable diagonal (per-channel) matrix whose values are initialized to a very small number.

I have found that this prevents the signal from exploding in very deep networks, making the training process much smoother from the start.

class LayerScale(layers.Layer):
    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Learnable per-channel scale, initialized to a small constant (e.g. 1e-5)
        self.gamma = self.add_weight(
            shape=(projection_dim,),
            initializer=tf.keras.initializers.Constant(init_values),
            trainable=True,
            name="gamma",
        )

    def call(self, x):
        # Multiplying the input by the learnable diagonal (per-channel) values
        return x * self.gamma
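
To see what the small initialization buys us, here is a quick, purely illustrative check: with gamma starting at 1e-5, the residual branch begins training almost silenced, so each block initially behaves close to an identity mapping.

# Illustrative sanity check: the scaled output starts out tiny
dummy = tf.ones((1, 4, PROJECTION_DIM))
scaled = LayerScale(1e-5, PROJECTION_DIM)(dummy)
print(float(tf.reduce_max(scaled)))  # roughly 1e-5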

Create the Multi-Head Self-Attention Block

The self-attention block is the heart of the transformer, where the model learns spatial relationships between different parts of the image.

I prefer using the built-in Keras MultiHeadAttention layer because it is highly optimized for GPU performance during training.

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def attention_block(x, projection_dim, num_heads, dropout, init_values):
    # Implementing the residual connection with LayerScale
    res = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout)(x, x)
    x = LayerScale(init_values, projection_dim)(x)
    x = layers.Add()([res, x])
    
    # Feed-forward network with LayerScale
    res = x
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = mlp(x, hidden_units=[projection_dim * 4, projection_dim], dropout_rate=dropout)
    x = LayerScale(init_values, projection_dim)(x)
    return layers.Add()([res, x])
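
As a quick, illustrative check, the block returns a tensor with exactly the shape it receives, which is what lets us stack NUM_LAYERS of them back to back.

# Illustrative shape check: a 224x224 image with 16x16 patches gives 196 tokens
tokens = layers.Input(shape=(196, PROJECTION_DIM))
out = attention_block(tokens, PROJECTION_DIM, NUM_HEADS, dropout=0.1, init_values=1e-5)
print(out.shape)  # (None, 196, 192)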

Build the Class Attention Layer

Class attention layers are unique to CaiT: they only update the class token while leaving the image patch embeddings untouched.

In my projects involving complex datasets like US medical imagery, this keeps the computational overhead of the final layers low, because attention is computed only between the single class token and the patches rather than between every pair of patches.

class ClassAttention(layers.Layer):
    def __init__(self, projection_dim, num_heads, dropout, init_values, **kwargs):
        super().__init__(**kwargs)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout)
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.ls = LayerScale(init_values, projection_dim)

    def call(self, x, cls_token):
        # Concatenating class token with patch embeddings for cross-attention
        joined = tf.concat([cls_token, x], axis=1)
        joined_norm = self.norm(joined)
        
        # Querying the class token against all patches
        cls_query = self.norm(cls_token)
        attn_out = self.attn(query=cls_query, value=joined_norm, key=joined_norm)
        
        return cls_token + self.ls(attn_out)
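
A quick illustrative call shows this in practice: however many patch tokens go in, only the single updated class token comes back out.

# Illustrative: the layer returns only the updated class token, never the patches
patch_in = layers.Input(shape=(196, PROJECTION_DIM))
cls_in = layers.Input(shape=(1, PROJECTION_DIM))
cls_out = ClassAttention(PROJECTION_DIM, NUM_HEADS, 0.1, 1e-5)(patch_in, cls_in)
print(cls_out.shape)  # (None, 1, 192)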

Patch Encoding for Image Data

Before feeding images into the transformer, we must break them down into smaller squares called patches.

I use a convolutional layer to perform this patch extraction as it helps the model capture local textures more effectively than simple reshaping.

def create_patch_embeddings(inputs, patch_size, projection_dim):
    # Using a Conv2D layer to create patches and project them simultaneously
    patches = layers.Conv2D(filters=projection_dim, kernel_size=patch_size,
                            strides=patch_size, padding="valid")(inputs)
    num_patches = (IMAGE_SIZE // patch_size) ** 2
    # Flattening the spatial grid into a sequence of patch tokens
    return layers.Reshape((num_patches, projection_dim))(patches)

Define the Full CaiT Model Architecture

Now we combine the patch embeddings, the self-attention layers, and the class-attention layers into one cohesive Keras model. The class token and the positional encoding both need to be trainable weights, so I wrap each of them in a small helper layer.

The original CaiT paper also uses Talking-Heads attention inside its self-attention blocks; to keep this tutorial simple, I stick with the standard Keras MultiHeadAttention layer instead.

class ClassToken(layers.Layer):
    # Learnable class token, tiled across the batch dimension at call time
    def __init__(self, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.cls = self.add_weight(shape=(1, 1, projection_dim), initializer="zeros",
                                   trainable=True, name="cls_token")

    def call(self, x):
        return tf.tile(self.cls, [tf.shape(x)[0], 1, 1])

class PositionalEmbedding(layers.Layer):
    # Learnable positional embedding added to the patch sequence
    def build(self, input_shape):
        self.pos = self.add_weight(shape=(1, input_shape[1], input_shape[2]),
                                   initializer="random_normal", trainable=True,
                                   name="pos_embedding")

    def call(self, x):
        return x + self.pos

def build_cait_model(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    # Step 1: Patching and Positional Encoding
    x = create_patch_embeddings(inputs, PATCH_SIZE, PROJECTION_DIM)
    x = PositionalEmbedding()(x)

    # Step 2: SA (Self-Attention) stages over the patch tokens
    init_values = 1e-5
    for _ in range(NUM_LAYERS):
        x = attention_block(x, PROJECTION_DIM, NUM_HEADS, 0.1, init_values)

    # Step 3: CA (Class-Attention) stages update only the class token
    cls_token = ClassToken(PROJECTION_DIM)(x)
    for _ in range(2):  # Usually 2 layers of CA are sufficient
        cls_token = ClassAttention(PROJECTION_DIM, NUM_HEADS, 0.1, init_values)(x, cls_token)

    # Step 4: Classification Head
    cls_token = layers.LayerNormalization(epsilon=1e-6)(cls_token)
    cls_token = layers.Flatten()(cls_token)  # (batch, 1, dim) -> (batch, dim)
    outputs = layers.Dense(num_classes, activation="softmax")(cls_token)

    return models.Model(inputs, outputs)

model = build_cait_model((IMAGE_SIZE, IMAGE_SIZE, 3), 10)
model.summary()
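
Before training, I like to run a quick smoke test on a random batch (illustrative only) to confirm the model produces one probability vector per image.

# Smoke test on random data (illustrative): expect an output of shape (2, 10)
dummy_batch = np.random.rand(2, IMAGE_SIZE, IMAGE_SIZE, 3).astype("float32")
print(model.predict(dummy_batch).shape)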

Compile and Train with a Custom Scheduler

For transformers, I’ve noticed that a learning rate warm-up is essential to keep training from diverging in the first few epochs.

I typically use the AdamW optimizer or Adam with a custom decay to ensure the LayerScale weights converge appropriately.

# Compiling the Keras model for a classification task
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

# Example: Training on a mock dataset representing US flora/fauna images
# history = model.fit(train_dataset, validation_data=val_dataset, epochs=50)
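
The compile call above uses a fixed learning rate for brevity. Below is a minimal sketch of the warm-up-then-decay behaviour I described, written as a custom LearningRateSchedule; the step counts are placeholders you would tune to your dataset, and on recent TensorFlow releases you can swap tf.keras.optimizers.AdamW in for Adam.

# Sketch: linear warm-up followed by cosine decay (placeholder step counts)
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, warmup_steps, total_steps):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Ramp up linearly from 0 to base_lr during the warm-up phase
        warmup_lr = self.base_lr * step / self.warmup_steps
        # Afterwards, decay with a cosine curve towards 0
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        progress = tf.clip_by_value(progress, 0.0, 1.0)
        cosine_lr = self.base_lr * 0.5 * (1.0 + tf.cos(np.pi * progress))
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)

# Example usage (hypothetical step counts):
# schedule = WarmupCosine(base_lr=1e-3, warmup_steps=500, total_steps=20000)
# optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)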

Evaluate Model Performance

After training, I always check the attention maps to ensure the class token is actually looking at the correct features in the image.

CaiT models usually show a much clearer separation of objects from the background compared to standard ViT models I’ve built.
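
The blocks above do not expose their attention weights, so to look at the maps you have to ask MultiHeadAttention for them explicitly. Here is a standalone sketch (not wired into the model above) of what return_attention_scores gives you for a class-token-style query.

# Standalone sketch: extracting attention weights from MultiHeadAttention
mha = layers.MultiHeadAttention(num_heads=NUM_HEADS, key_dim=PROJECTION_DIM)
cls_q = tf.random.normal((1, 1, PROJECTION_DIM))        # a class-token-like query
patch_kv = tf.random.normal((1, 196, PROJECTION_DIM))   # patch tokens
_, scores = mha(query=cls_q, value=patch_kv, return_attention_scores=True)
print(scores.shape)  # (1, 8, 1, 196): one attention map per head over the 196 patches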

def evaluate_cait_performance(model, test_data):
    # Running evaluation on the test set to verify accuracy
    results = model.evaluate(test_data)
    print(f"Test Loss: {results[0]}, Test Accuracy: {results[1]}")
    return results


In this article, I showed you how to implement Class Attention Image Transformers with LayerScale in Keras to build more stable and deeper vision models.

I have found that using LayerScale is a game-changer when you want to train models with more than 12 layers without hitting a performance ceiling.
