Image Classification with Vision Transformer in Keras

I was working on a Python project where I needed to classify thousands of images quickly and accurately. At first, I tried a traditional Convolutional Neural Network (CNN), but its accuracy plateaued.

That’s when I decided to explore Vision Transformer (ViT), a model that adapts the Transformer architecture (originally used for NLP) for image classification. In this tutorial, I’ll show you how I implemented an image classification model using Keras and Python with the Vision Transformer architecture.

If you’re familiar with Python and Keras but new to Vision Transformers, don’t worry.
I’ll walk you through everything step by step, from importing data to training and evaluating the model.

What is a Vision Transformer (ViT)?

Before diving into the Python code, let me quickly explain what a Vision Transformer is.
Unlike CNNs that use convolutional filters, ViTs divide an image into small patches and treat each patch as a “token,” similar to words in a sentence.

Each token is then passed through a Transformer encoder that learns relationships between patches using self-attention. This approach allows the model to capture global dependencies in an image more effectively than CNNs.
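To make this concrete with the numbers we’ll use later in this tutorial: a 32×32 CIFAR image split into 4×4 patches produces (32/4)² = 64 tokens, each a flattened vector of 4×4×3 = 48 pixel values. A quick back-of-the-envelope check in Python:

# Patch/token arithmetic for a 32x32 RGB image and 4x4 patches
image_size, patch_size, channels = 32, 4, 3
num_patches = (image_size // patch_size) ** 2      # 64 tokens per image
patch_dim = patch_size * patch_size * channels     # 48 values per token
print(num_patches, patch_dim)                      # 64 48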

Why Use Keras for Vision Transformers in Python?

I’ve been using Keras for more than four years, and it’s my go-to deep learning library for Python. It’s simple, flexible, and integrates seamlessly with TensorFlow, making it perfect for implementing complex architectures like ViT.

Keras also provides utilities for loading datasets, preprocessing images, and visualizing results, all of which we’ll use in this tutorial.
So, let’s get started!

Step 1 – Import Required Python Libraries

Before we begin, let’s import the necessary Python libraries. We’ll use TensorFlow’s Keras API along with NumPy and Matplotlib for data handling and visualization.

# Import necessary Python libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

print("TensorFlow version:", tf.__version__)

This Python code imports all the essential libraries and prints the TensorFlow version for confirmation. You should see something like TensorFlow version: 2.15.0 if you’re using a recent version.

Step 2 – Load and Prepare the Dataset

For this tutorial, I’ll use the CIFAR-100 dataset, which contains 100 classes of images (each 32×32 pixels). It’s a great dataset for testing image classification models in Python.

# Load CIFAR-100 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Normalize pixel values
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

print("Training data shape:", x_train.shape)
print("Test data shape:", x_test.shape)

Normalizing pixel values between 0 and 1 helps the Vision Transformer learn faster and more effectively. This is a standard preprocessing step in most Python-based deep learning workflows.

Step 3 – Create Image Patches

Vision Transformers don’t process the entire image at once. Instead, they divide images into small patches that are later flattened and embedded.

# Define patch creation layer
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

This custom Python class breaks each image into small square patches that can later be fed into the Transformer encoder. It’s an essential step in implementing ViT.
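If you want to verify the layer before moving on, here is a quick sanity check (it assumes the imports from Step 1 and the Patches class above):

# Sanity check: a fake batch of 8 CIFAR-sized images
sample = tf.random.uniform(shape=(8, 32, 32, 3))
patches = Patches(patch_size=4)(sample)
print(patches.shape)  # (8, 64, 48): 64 patches per image, 4*4*3 = 48 values each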

Step 4 – Encode Patches Using an Embedding Layer

Once we have patches, we project each one into a vector space and add a learnable positional embedding.
This lets the model keep track of where each patch sits in the original image.

# Define patch encoding layer
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

This Python class adds positional embeddings to each patch, allowing the model to retain spatial information. Without positional encoding, the Transformer wouldn’t know the order of patches.
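Continuing the sanity check from Step 3, you can confirm that each 48-value patch becomes a 64-dimensional token:

# Embed the 64 patches from the previous sanity check
encoder = PatchEncoder(num_patches=64, projection_dim=64)
encoded = encoder(patches)
print(encoded.shape)  # (8, 64, 64): batch of 8, 64 tokens, 64 dimensions each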

Step 5 – Build the Vision Transformer Model

Now that we’ve prepared patches and encodings, let’s build the full Vision Transformer model in Keras. We’ll define the Transformer blocks, including Multi-Head Attention and MLP layers.

def create_vit_classifier(
    input_shape=(32, 32, 3),
    patch_size=4,
    num_patches=(32 // 4) ** 2,
    projection_dim=64,
    transformer_layers=8,
    num_heads=4,
    mlp_head_units=[2048, 1024],
    num_classes=100,
):
    inputs = layers.Input(shape=input_shape)
    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple Transformer blocks
    for _ in range(transformer_layers):
        # Layer normalization 1
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head attention
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP block: expand, then project back to projection_dim so the
        # tensor shapes match for the residual (skip) connection below
        x3 = layers.Dense(projection_dim * 2, activation=tf.nn.gelu)(x3)
        x3 = layers.Dense(projection_dim, activation=tf.nn.gelu)(x3)
        # Skip connection 2
        encoded_patches = layers.Add()([x3, x2])

    # Classification head: the wider MLP (mlp_head_units) belongs here
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    for units in mlp_head_units:
        representation = layers.Dense(units, activation=tf.nn.gelu)(representation)
        representation = layers.Dropout(0.5)(representation)
    logits = layers.Dense(num_classes)(representation)

    # Create the Keras model
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

This Python function creates a complete Vision Transformer model using Keras layers.
Note that inside each Transformer block the MLP projects back down to projection_dim, so the skip connections add tensors of matching shape; the wider mlp_head_units layers appear only in the classification head.
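Before training, it’s worth building the model once and printing a summary to confirm that the layer shapes line up (Step 6 creates the instance we actually train):

# Build a throwaway instance and inspect the architecture
model = create_vit_classifier()
model.summary()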

Step 6 – Compile and Train the Model

Now that our model is ready, let’s compile it using the Adam optimizer and Sparse Categorical Crossentropy loss.
Then, we’ll train it on the CIFAR-100 dataset.

# Create and compile the model
vit_model = create_vit_classifier()
vit_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train the model
history = vit_model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=10,
    validation_split=0.1,
)

Training a Vision Transformer in Python can take some time depending on your hardware. If you’re using a GPU, you’ll notice a significant speed improvement.
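If training time is a concern, Keras callbacks can stop training once validation accuracy stalls and keep the best weights. A minimal sketch (the monitored metric, patience, and file name are my choices, not requirements):

# Optional: stop early and checkpoint the best weights
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=3, restore_best_weights=True
    ),
    keras.callbacks.ModelCheckpoint("vit_cifar100.keras", save_best_only=True),
]

# Then pass callbacks=callbacks to vit_model.fit(...)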

Step 7 – Evaluate and Visualize Results

Once training is complete, let’s evaluate the model on the test dataset and visualize accuracy trends. This helps us understand how well the model generalizes.

# Evaluate the model
test_loss, test_acc = vit_model.evaluate(x_test, y_test)
print("Test Accuracy:", test_acc)

# Plot training and validation accuracy
plt.plot(history.history["accuracy"], label="Training Accuracy")
plt.plot(history.history["val_accuracy"], label="Validation Accuracy")
plt.title("Vision Transformer Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

Training from scratch for only 10 epochs on CIFAR-100 usually yields modest accuracy; the exact number depends on your hardware and hyperparameters. With data augmentation, longer training, or careful fine-tuning, you can push it considerably higher.
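One inexpensive improvement is data augmentation. Here is a minimal sketch using Keras preprocessing layers (the flip, rotation, and zoom factors are just starting points); you would apply it to the inputs inside create_vit_classifier, before the Patches layer:

# Simple augmentation pipeline, e.g. augmented = data_augmentation(inputs)
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(factor=0.02),
    layers.RandomZoom(height_factor=0.2),
])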

Alternative: Use Pretrained Vision Transformer Models

If you don’t want to train from scratch, TensorFlow Hub hosts pretrained ViT checkpoints that you can load into Keras and fine-tune on your own dataset.
Model handles change over time, so check tfhub.dev for the current listings.

# Example of using a pretrained ViT model
# (verify the exact model handle on tfhub.dev before running)
import tensorflow_hub as hub

pretrained_model = keras.Sequential([
    layers.Input(shape=(224, 224, 3)),  # ViT-B/16 expects 224x224 RGB inputs
    hub.KerasLayer(
        "https://tfhub.dev/google/vit_base_patch16_224/1",
        trainable=False,  # freeze the backbone; set True to fine-tune it
    ),
    layers.Dense(100, activation='softmax')
])

pretrained_model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

Using pretrained models is a great option if you’re working with limited data or want faster results. You can fine-tune them for specific image categories, like wildlife or traffic signs.
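One practical note if you fine-tune this setup on CIFAR-100: a ViT-B/16 checkpoint expects 224×224 inputs, so the 32×32 images must be resized first. A sketch using tf.data (the batch size is arbitrary):

# Resize CIFAR images to the 224x224 resolution the pretrained ViT expects
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(lambda x, y: (tf.image.resize(x, (224, 224)), y))
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
# pretrained_model.fit(train_ds, epochs=5)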

Conclusion

In this Python tutorial, I showed you how to build an image classification model using Vision Transformer (ViT) in Keras. We covered everything from dataset loading and patch creation to model training and evaluation.

While CNNs are still powerful, Vision Transformers offer a fresh and highly effective approach to image understanding. With Keras and Python, implementing ViT becomes straightforward and efficient.

If you’re working on any real-world image classification project, whether it’s detecting road signs, classifying retail products, or analyzing satellite imagery, I highly recommend giving Vision Transformers a try.
