When I first started exploring Vision Transformers (ViT) in Keras, I was amazed by their performance on large datasets like ImageNet. However, training these models on small datasets can be tricky. From my experience, with the right approach and some tweaks, you can successfully train a ViT for image classification even when the data is limited.
In this article, I’ll walk you through how to build and train a Vision Transformer on small datasets using Keras. I’ll share practical tips and provide a complete working example so you can get started quickly.
What is a Vision Transformer (ViT) in Keras?
Vision Transformers apply the transformer architecture, originally developed for natural language processing, to images. Instead of convolutions, ViT splits images into patches and processes them as sequences, capturing long-range dependencies effectively.
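Using Keras, you can implement ViT models that leverage this architecture for image classification tasks. Because a ViT lacks the built-in locality and translation-equivariance assumptions of a convolutional network, it usually needs more data to learn them. Even so, it can be adapted to smaller datasets through careful model design and training strategies.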
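To make the "image as a sequence" idea concrete, here is a small sketch of the arithmetic behind patching. The helper name patch_grid is my own (it is not a Keras function), and it assumes square images and square, non-overlapping patches.

```python
def patch_grid(image_size, patch_size, channels=3):
    """Return (number of patches, values per flattened patch) for a square image.

    Hypothetical helper for illustration; assumes non-overlapping square patches.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size should be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side ** 2, patch_size * patch_size * channels


print(patch_grid(72, 6))    # (144, 108): 72x72 image, 6x6 patches
print(patch_grid(224, 16))  # (196, 768): 224x224 image, 16x16 patches
```

So a 72×72 RGB image cut into 6×6 patches becomes a sequence of 144 tokens, each holding 108 raw pixel values before any projection.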
Prepare Your Small Dataset for ViT in Keras
Small datasets require careful preprocessing to avoid overfitting. Here’s what I do:
- Resize images to a fixed size (e.g., 72×72 pixels) to create uniform patches.
- Normalize pixel values (for example, scale them to the [0, 1] range) to speed up training convergence.
- Augment data with flips, rotations, and zooms to artificially increase the dataset size.
This simple preprocessing pipeline helps the Vision Transformer learn meaningful features without memorizing the data.
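One way to express these steps is with Keras preprocessing layers. The following is a minimal sketch rather than the only option: the augmentation strengths are assumptions I picked for illustration, and the random input batch is stand-in data, not a real dataset.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

IMAGE_SIZE = 72  # matches the 72x72 example size used in this article

# Deterministic steps: applied to every image, at training and inference time.
preprocess = tf.keras.Sequential([
    layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    layers.Rescaling(1.0 / 255),  # map [0, 255] to [0, 1]
])

# Random steps: only active when called with training=True.
# The factors below are illustrative assumptions, not tuned values.
augment = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.04),  # fraction of a full turn, roughly +/-14 degrees
    layers.RandomZoom(0.1),       # zoom in or out by up to 10%
])

# Stand-in batch of 8 random 32x32 RGB images (hypothetical data).
images = np.random.randint(0, 256, size=(8, 32, 32, 3)).astype("float32")
clean = preprocess(images)
augmented = augment(clean, training=True)
print(clean.shape, augmented.shape)
```

Keeping the random layers separate from the deterministic ones makes it easy to apply only resizing and rescaling to validation and test images.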
Method 1: Build a Vision Transformer from Scratch in Keras
I like to start by building a ViT model manually to understand its components. The key steps:
- Patch Extraction: Split the image into fixed-size patches.
- Patch Encoding: Linearly project each flattened patch and add a learned positional embedding.
- Transformer Encoder: Apply multi-head self-attention and feed-forward layers.
- Classification Head: Use a dense layer to classify the image.
Here’s a concise example:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Parameters
image_size = 72  # Resize images to 72x72
patch_size = 6  # Size of each patch
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_layers = 8
mlp_dim = 128
num_classes = 10  # Example: 10 classes


# Patch extraction layer
class PatchExtractor(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        # Flatten the patch grid into a sequence: (batch, num_patches, patch_dims)
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches


# Patch encoding layer
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        # Store the value on the layer instead of relying on the global variable
        self.num_patches = num_patches
        self.projection = layers.Dense(projection_dim)
        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patches):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded


# Build the ViT model
def create_vit_classifier():
    inputs = layers.Input(shape=(image_size, image_size, 3))
    patches = PatchExtractor(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Transformer blocks
    for _ in range(transformer_layers):
        # Layer normalization 1
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Multi-head attention
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(x1, x1)
        # Skip connection 1
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP
        mlp_output = layers.Dense(mlp_dim, activation='relu')(x3)
        mlp_output = layers.Dense(projection_dim)(mlp_output)
        # Skip connection 2
        encoded_patches = layers.Add()([mlp_output, x2])

    # Classification head
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    outputs = layers.Dense(num_classes, activation='softmax')(representation)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Instantiate and compile the model
model = create_vit_classifier()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

This model is lightweight enough to train on small datasets but still captures the essential transformer architecture.
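To put a number on "lightweight", here is a rough back-of-the-envelope parameter count for the configuration above (72×72 images, 6×6 patches, projection size 64, 4 heads, 8 blocks). It is a sketch under my own assumptions: the function names are hypothetical, and it assumes the attention layer uses biased query, key, value, and output projections with key_dim values per head. Treat the total as an estimate to compare against model.summary(), not as a guaranteed match.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def vit_param_count(image_size=72, patch_size=6, channels=3, projection_dim=64,
                    num_heads=4, transformer_layers=8, mlp_dim=128, num_classes=10):
    """Estimate trainable parameters (hypothetical helper for illustration)."""
    num_patches = (image_size // patch_size) ** 2
    patch_dim = patch_size * patch_size * channels

    # Patch encoder: linear projection plus one learned embedding per position.
    encoder = dense_params(patch_dim, projection_dim) + num_patches * projection_dim

    # Attention with key_dim=projection_dim per head (assumed layout):
    # query, key, and value projections plus the output projection.
    inner = num_heads * projection_dim
    attention = 3 * dense_params(projection_dim, inner) + dense_params(inner, projection_dim)

    layer_norm = 2 * projection_dim  # scale and offset
    mlp = dense_params(projection_dim, mlp_dim) + dense_params(mlp_dim, projection_dim)
    block = attention + 2 * layer_norm + mlp

    # Head: final layer norm, then a dense layer over the flattened tokens.
    head = layer_norm + dense_params(num_patches * projection_dim, num_classes)

    return encoder + transformer_layers * block + head


print(vit_param_count())  # 774090
```

Under these assumptions the model has roughly 0.77 million parameters. Note that the flattened classification head alone contributes about 92,000 of them, so the head grows quickly if you increase the image size or shrink the patches.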
Method 2: Use Transfer Learning with Pretrained Vision Transformers in Keras
When working with small datasets, transfer learning is a lifesaver. Instead of training from scratch, I use pretrained ViT models and fine-tune them on my dataset.
Keras and TensorFlow Hub offer pretrained ViT models you can easily load. Here’s how I do it:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers, models

num_classes = 10  # Example: 10 classes

# Load a pretrained ViT model from TF Hub
vit_layer = hub.KerasLayer(
    "https://tfhub.dev/sayakpaul/vit_b16_fe/1",
    trainable=True)

# The pretrained backbone is used with 224x224 RGB inputs
inputs = layers.Input(shape=(224, 224, 3))
x = vit_layer(inputs)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(num_classes, activation='softmax')(x)

model = models.Model(inputs=inputs, outputs=outputs)

# A small learning rate helps avoid overwriting the pretrained weights
# (1e-4 is a starting point to tune, not a recommended constant)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

Fine-tuning lets you leverage knowledge from large datasets, improving performance dramatically on small data.
Training Tips for Vision Transformers on Small Datasets with Keras
From my experience, these tips make a big difference:
- Use data augmentation to increase diversity.
- Apply dropout and regularization to reduce overfitting.
- Try smaller batch sizes; the noisier gradient updates can act as a mild regularizer.
- Use early stopping callbacks to prevent overtraining.
- Consider learning rate schedules to improve convergence.
These simple strategies help you get the most out of your small dataset.
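For the learning rate schedule tip, here is a minimal sketch of linear warmup followed by cosine decay, written as a plain Python function of the epoch index. The factory name make_warmup_cosine and all default values are my own assumptions for illustration; tune them for your data.

```python
import math


def make_warmup_cosine(base_lr=1e-3, min_lr=1e-5, warmup_epochs=3, total_epochs=30):
    """Build a schedule(epoch, lr) function (hypothetical helper, illustrative defaults)."""

    def schedule(epoch, lr=None):
        # The incoming lr is ignored; the rate depends only on the epoch index.
        if epoch < warmup_epochs:
            # Linear warmup: reach base_lr at the end of the warmup phase.
            return base_lr * (epoch + 1) / warmup_epochs
        # Cosine decay from base_lr down to min_lr over the remaining epochs.
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        progress = min(progress, 1.0)
        return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

    return schedule


schedule = make_warmup_cosine()
for epoch in (0, 2, 3, 15, 29):
    print(epoch, round(schedule(epoch), 6))
```

Because the function takes an epoch index and the current learning rate and returns a float, it should plug into keras.callbacks.LearningRateScheduler(schedule) and can sit next to an EarlyStopping callback in the callbacks list passed to model.fit.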
Putting It All Together: Full Training Example
Here’s a complete example training a ViT from scratch on a small dataset like CIFAR-10 resized to 72×72 pixels:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load and preprocess CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Resize images to 72x72 for patching
x_train = tf.image.resize(x_train, (72, 72)).numpy()
x_test = tf.image.resize(x_test, (72, 72)).numpy()

# Normalize pixel values
x_train = x_train / 255.0
x_test = x_test / 255.0

# Data augmentation
# (datagen.fit is omitted: it is only needed for featurewise statistics
# or ZCA whitening, which are not enabled here)
datagen = ImageDataGenerator(
    rotation_range=15,
    horizontal_flip=True,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1
)

# Create ViT model (use the create_vit_classifier function from above)
model = create_vit_classifier()

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
history = model.fit(datagen.flow(x_train, y_train, batch_size=64),
                    validation_data=(x_test, y_test),
                    epochs=30,
                    callbacks=[tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)])

This example shows how to prepare data, build the model, and train it effectively on a small dataset.
Training Vision Transformers on small datasets with Keras is definitely achievable with the right approach. Whether you build your model from scratch or use transfer learning, combining good preprocessing, augmentation, and training strategies will help you get solid results.
If you follow this guide and experiment with hyperparameters, you’ll soon harness the power of ViTs for your image classification projects.
I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.