I was working on a Python project where I needed to classify thousands of images quickly and accurately. At first, I tried using a traditional Convolutional Neural Network (CNN), but the accuracy plateaued after a point.
That’s when I decided to explore Vision Transformer (ViT), a model that adapts the Transformer architecture (originally used for NLP) for image classification. In this tutorial, I’ll show you how I implemented an image classification model using Keras and Python with the Vision Transformer architecture.
If you’re familiar with Python and Keras but new to Vision Transformers, don’t worry.
I’ll walk you through everything step by step, from importing data to training and evaluating the model.
What is a Vision Transformer (ViT)?
Before diving into the Python code, let me quickly explain what a Vision Transformer is.
Unlike CNNs that use convolutional filters, ViTs divide an image into small patches and treat each patch as a “token,” similar to words in a sentence.
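Each token is then passed through a Transformer encoder that learns relationships between patches using self-attention. Because every patch can attend to every other patch in each layer, the model can capture global dependencies from the very first layer, whereas a CNN builds up its receptive field gradually through stacked convolutions.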
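To make the self-attention idea concrete, here is a minimal NumPy sketch of single-head scaled dot-product attention over a set of patch tokens. The sizes and the random weight matrices are placeholders I picked for illustration; a real ViT learns these weights and uses several heads.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over patch tokens."""
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])   # (num_patches, num_patches)
    weights = softmax(scores, axis=-1)        # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
num_patches, dim = 64, 16                     # illustrative sizes only
tokens = rng.normal(size=(num_patches, dim))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))

attended, weights = self_attention(tokens, w_q, w_k, w_v)
print(attended.shape)   # (64, 16)
print(weights.shape)    # (64, 64)
```

The 64×64 weight matrix is the "global" part: every patch gets a weight for every other patch, regardless of how far apart they are in the image.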
Use Keras for Vision Transformer in Python
I’ve been using Keras for more than four years, and it’s my go-to deep learning library for Python. It’s simple, flexible, and integrates seamlessly with TensorFlow, making it perfect for implementing complex architectures like ViT.
Keras also provides utilities for loading datasets, preprocessing images, and visualizing results, all of which we’ll use in this tutorial.
So, let’s get started!
Step 1 – Import Required Python Libraries
Before we begin, let’s import the necessary Python libraries. We’ll use TensorFlow’s Keras API along with NumPy and Matplotlib for data handling and visualization.
# Import necessary Python libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
print("TensorFlow version:", tf.__version__)
This Python code imports all the essential libraries and prints the TensorFlow version for confirmation. You should see something like TensorFlow version: 2.15.0 if you’re using a recent version.
Step 2 – Load and Prepare the Dataset
For this tutorial, I’ll use the CIFAR-100 dataset, which contains 100 classes of images (each 32×32 pixels). It’s a great dataset for testing image classification models in Python.
# Load CIFAR-100 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
# Normalize pixel values
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
print("Training data shape:", x_train.shape)
print("Test data shape:", x_test.shape)
Normalizing pixel values between 0 and 1 helps the Vision Transformer learn faster and more effectively. This is a standard preprocessing step in most Python-based deep learning workflows.
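If you want to sanity-check the scaling step without downloading the dataset, here is a small sketch on synthetic data. The arrays are random stand-ins that I shaped like a CIFAR batch (uint8 images, integer labels of shape (N, 1)); they are not real CIFAR-100 samples.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for a CIFAR-style batch: uint8 pixels and integer labels
fake_images = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
fake_labels = rng.integers(0, 100, size=(8, 1))

# Same scaling as in the tutorial: cast to float32, then divide by 255
scaled = fake_images.astype("float32") / 255.0

print(scaled.dtype, float(scaled.min()) >= 0.0, float(scaled.max()) <= 1.0)
print(fake_labels.shape)
```

The labels stay as integer class IDs rather than one-hot vectors, which is why a sparse categorical loss is the natural fit later on.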
Step 3 – Create Image Patches
Vision Transformers don’t process the entire image at once. Instead, they divide images into small patches that are later flattened and embedded.
# Define patch creation layer
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
This custom Python class breaks each image into small square patches that can later be fed into the Transformer encoder. It’s an essential step in implementing ViT.
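To see what the patch step produces, here is a NumPy-only sketch of the same idea. The helper name extract_patches_np is my own, and I am assuming non-overlapping patches flattened in row, column, channel order, which is how I understand tf.image.extract_patches to behave with these settings.

```python
import numpy as np

def extract_patches_np(images, patch_size):
    """Split (batch, H, W, C) images into flattened non-overlapping patches."""
    b, h, w, c = images.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size          # patch grid size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)                  # (b, gh, gw, p, p, c)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

images = np.arange(2 * 32 * 32 * 3, dtype=np.float32).reshape(2, 32, 32, 3)
patches = extract_patches_np(images, patch_size=4)
print(patches.shape)  # (2, 64, 48)
```

For a 32×32 RGB image and a patch size of 4, that gives (32 // 4) ** 2 = 64 patches, each holding 4 × 4 × 3 = 48 values.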
Step 4 – Encode Patches Using an Embedding Layer
Once we have patches, we project each flattened patch into a fixed-size vector and add a learned position embedding.
The position embedding is what lets the model reason about where each patch sits in the image.
# Define patch encoding layer
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
This Python class adds positional embeddings to each patch, allowing the model to retain spatial information. Without positional encoding, the Transformer wouldn’t know the order of patches.
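Here is a small NumPy experiment that illustrates why the position embedding matters. It uses a simplified, parameter-free attention function (queries, keys, and values are all the tokens themselves), and the random pos table is only a stand-in for a learned embedding.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attend(tokens):
    """Parameter-free self-attention: queries, keys, values are the tokens."""
    scores = tokens @ tokens.T / np.sqrt(tokens.shape[-1])
    return softmax(scores) @ tokens

rng = np.random.default_rng(0)
tokens = rng.normal(size=(64, 16))
perm = rng.permutation(64)

# Without position information, shuffling patches just shuffles the outputs
no_pos_equivariant = np.allclose(attend(tokens)[perm], attend(tokens[perm]))

# With a per-position embedding added, the same shuffle changes the result
pos = rng.normal(size=(64, 16))   # stand-in for a learned embedding table
with_pos_equivariant = np.allclose(
    attend(tokens + pos)[perm], attend(tokens[perm] + pos)
)
print(no_pos_equivariant, with_pos_equivariant)  # True False
```

In other words, plain attention treats the patches as an unordered set; adding a position-specific vector is what makes patch order visible to the model.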
Step 5 – Build the Vision Transformer Model
Now that we’ve prepared patches and encodings, let’s build the full Vision Transformer model in Keras. We’ll define the Transformer blocks, including Multi-Head Attention and MLP layers.
def create_vit_classifier(
    input_shape=(32, 32, 3),
    patch_size=4,
    num_patches=(32 // 4) ** 2,
    projection_dim=64,
    transformer_layers=8,
    num_heads=4,
    mlp_head_units=(2048, 1024),
    num_classes=100,
):
    inputs = layers.Input(shape=input_shape)
    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    # Create multiple Transformer blocks
    for _ in range(transformer_layers):
        # Layer normalization 1
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head attention
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP block: expand, then project back to projection_dim so the
        # output shape matches x2 for the residual addition below
        x3 = layers.Dense(projection_dim * 2, activation=tf.nn.gelu)(x3)
        x3 = layers.Dense(projection_dim, activation=tf.nn.gelu)(x3)
        # Skip connection 2
        encoded_patches = layers.Add()([x3, x2])
    # Classification head
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Dense layers of the classification head
    for units in mlp_head_units:
        representation = layers.Dense(units, activation=tf.nn.gelu)(representation)
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
This Python function creates a complete Vision Transformer model using Keras layers. Note that the MLP inside each Transformer block ends with projection_dim units so that its output can be added back to the skip connection, while mlp_head_units sets the sizes of the dense layers in the classification head.
It includes multiple Transformer blocks and a final classification head.
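One detail worth checking by hand is the shape contract of the skip connections: whatever a sub-layer outputs must have the same shape as the tensor it is added to. Here is a NumPy sketch of the MLP half of a block (layer norm, two dense layers with GELU, residual add). The weights are random placeholders, and the hidden width of 128 is just an illustrative choice.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token over its feature dimension (no learned scale/shift)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, w1, w2):
    return gelu(gelu(x @ w1) @ w2)

rng = np.random.default_rng(0)
num_patches, dim, hidden = 64, 64, 128
x = rng.normal(size=(num_patches, dim))
w1 = rng.normal(size=(dim, hidden)) * 0.02
w2 = rng.normal(size=(hidden, dim)) * 0.02     # projects back to `dim`

out = x + mlp(layer_norm(x), w1, w2)           # residual add needs equal shapes
print(out.shape)  # (64, 64)

# If the last layer does not project back to `dim`, the add cannot work
w2_wide = rng.normal(size=(hidden, 1024)) * 0.02
try:
    x + mlp(layer_norm(x), w1, w2_wide)
    shapes_compatible = True
except ValueError:
    shapes_compatible = False
print(shapes_compatible)  # False
```

So when you tune the MLP widths inside a Transformer block, only the intermediate width is free; the final width has to equal the token dimension.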
Step 6 – Compile and Train the Model
Now that our model is ready, let’s compile it using the Adam optimizer and Sparse Categorical Crossentropy loss.
Then, we’ll train it on the CIFAR-100 dataset.
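Before the training code, it may help to see what Sparse Categorical Crossentropy on logits computes. The NumPy function below is my own simplified version of the idea (log-softmax, then pick the log-probability of the true class), not the Keras implementation.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean cross-entropy given raw logits and integer class labels."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.zeros((4, 100))          # a model that is maximally unsure
labels = np.array([3, 17, 42, 99])   # integer class IDs, not one-hot
loss = sparse_ce_from_logits(logits, labels)
print(round(float(loss), 3))  # 4.605, i.e. ln(100)
```

This gives a handy sanity check: with 100 classes, a freshly initialized model should start with a loss of roughly ln(100) ≈ 4.6, and the value should fall from there as training progresses.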
# Create and compile the model
vit_model = create_vit_classifier()
vit_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
# Train the model
history = vit_model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=10,
    validation_split=0.1,
)
Training a Vision Transformer in Python can take some time depending on your hardware. If you’re using a GPU, you’ll notice a significant speed improvement.
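To get a feel for how much work the training call does, here is a quick back-of-the-envelope calculation using the settings above. I am assuming the standard CIFAR-100 training set of 50,000 images.

```python
import math

num_train = 50_000          # CIFAR-100 training images
validation_split = 0.1
batch_size = 64
epochs = 10

# Samples left for gradient updates after holding out the validation split
num_fit = num_train - int(num_train * validation_split)
steps_per_epoch = math.ceil(num_fit / batch_size)
total_steps = steps_per_epoch * epochs
print(num_fit, steps_per_epoch, total_steps)  # 45000 704 7040
```

That is about 7,000 optimizer steps in total, which is a short run for a Transformer trained from scratch.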
Step 7 – Evaluate and Visualize Results
Once training is complete, let’s evaluate the model on the test dataset and visualize accuracy trends. This helps us understand how well the model generalizes.
# Evaluate the model
test_loss, test_acc = vit_model.evaluate(x_test, y_test)
print("Test Accuracy:", test_acc)
# Plot training and validation accuracy
plt.plot(history.history["accuracy"], label="Training Accuracy")
plt.plot(history.history["val_accuracy"], label="Validation Accuracy")
plt.title("Vision Transformer Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
You can refer to the screenshot below to see the output.

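Your exact numbers will depend on hyperparameters and random initialization, but don’t expect high accuracy after only 10 epochs: a ViT trained from scratch on CIFAR-100 without data augmentation learns slowly. For comparison, the official Keras ViT example reports roughly 55% top-1 test accuracy after 100 epochs. Results in the 60–70% range generally call for longer training, data augmentation, or a pretrained model.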
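Since the model outputs raw logits, you can compute predictions with argmax directly (softmax does not change the ranking). With 100 classes, top-k accuracy is also worth tracking. The helper below is a small NumPy sketch of my own with made-up logits for three samples; it expects 1-D integer labels, so flatten CIFAR’s (N, 1) label array first.

```python
import numpy as np

def top_k_accuracy(logits, labels, k=5):
    """Fraction of samples whose true label is among the k largest logits."""
    top_k = np.argsort(logits, axis=-1)[:, -k:]
    return float(np.mean([label in row for label, row in zip(labels, top_k)]))

# Made-up logits for 3 samples and 4 classes
logits = np.array([
    [0.1, 2.0, 0.3, 0.0],
    [1.5, 0.2, 0.1, 1.4],
    [0.0, 0.1, 0.2, 3.0],
])
labels = np.array([1, 3, 0])

top1 = top_k_accuracy(logits, labels, k=1)
top2 = top_k_accuracy(logits, labels, k=2)
print(round(top1, 3), round(top2, 3))  # 0.333 0.667
```

The second sample shows why top-k is informative: the true class narrowly misses first place but is still the model’s second choice.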
Alternative: Use Pretrained Vision Transformer Models
If you don’t want to train from scratch, Keras and TensorFlow Hub offer pretrained ViT models for Python developers.
You can load them and fine-tune on your own dataset easily.
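# Example of using a pretrained ViT model
import tensorflow_hub as hub
pretrained_model = keras.Sequential([
    # Handle as given in this tutorial; check that it resolves on TensorFlow Hub
    # and which input size it expects before relying on it
    hub.KerasLayer("https://tfhub.dev/google/vit_base_patch16_224/1"),
    layers.Dense(100, activation="softmax"),
])
pretrained_model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
Using pretrained models is a great option if you’re working with limited data or want faster results. You can fine-tune them for specific image categories (like wildlife or traffic signs in the U.S.). Keep in mind that hub.KerasLayer is frozen by default (pass trainable=True to fine-tune the backbone), and that a patch16_224 model expects 224×224 inputs, so CIFAR’s 32×32 images would need to be resized first.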
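Pretrained ViT checkpoints are usually tied to a fixed input resolution, so small images have to be upscaled before they reach the model. Here is a NumPy-only sketch of nearest-neighbour upsampling from 32×32 to 224×224; the helper name is my own, and in a real pipeline you would more likely use a bilinear resize such as tf.image.resize or a Keras Resizing layer.

```python
import numpy as np

def upsample_nearest(images, factor):
    """Nearest-neighbour upsampling of (batch, H, W, C) images by an integer factor."""
    return np.repeat(np.repeat(images, factor, axis=1), factor, axis=2)

# Random stand-in for two normalized CIFAR-sized images
small = np.random.default_rng(0).random((2, 32, 32, 3), dtype=np.float32)
large = upsample_nearest(small, factor=224 // 32)   # 32 * 7 = 224
print(large.shape)  # (2, 224, 224, 3)
```

Each source pixel becomes a 7×7 block, so no new information is created; the resize only makes the tensor shape compatible with the pretrained backbone.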
Conclusion
In this Python tutorial, I showed you how to build an image classification model using Vision Transformer (ViT) in Keras. We covered everything from dataset loading and patch creation to model training and evaluation.
While CNNs are still powerful, Vision Transformers offer a fresh and highly effective approach to image understanding. With Keras and Python, implementing ViT becomes straightforward and efficient.
If you’re working on any real-world image classification project, whether it’s detecting road signs, classifying retail products, or analyzing satellite imagery, I highly recommend giving Vision Transformers a try.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.