In my four years of working with Keras, I’ve realized that moving from traditional CNNs to Vision Transformers (ViT) is a massive shift.
The most confusing part for many developers I mentor is how we actually turn a standard image into a sequence of tokens that a Transformer can understand.
In this tutorial, I will show you exactly how I handle tokenization in Keras. The examples use random tensors as stand-ins for real-world images such as California housing satellite imagery or retail product photos.
What is Tokenization in Keras Vision Transformers?
Tokenization is the process of breaking an image into smaller, manageable “patches” that act like words in a sentence.
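In a Keras ViT, we don’t feed pixels in one at a time. Instead, we split the image into a grid of fixed-size patches (for example, 16x16 pixels), flatten each patch into a vector, and treat each vector as one token, so local spatial information stays together inside each token.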
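To make the idea concrete, here is a minimal sketch in plain NumPy (no Keras involved) of what patching computes for a 224x224 RGB image with 16x16 patches: 14 x 14 = 196 tokens, each holding 16 x 16 x 3 = 768 values. The `patchify` helper is a hypothetical name used only for illustration, and it assumes the height and width divide evenly by the patch size.

```python
import numpy as np

def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into a (num_patches, patch_size*patch_size*C) token matrix."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "H and W must be divisible by patch_size"
    gh, gw = h // patch_size, w // patch_size  # patches per column / per row
    # (gh, p, gw, p, C): split each spatial axis into (grid index, offset inside the patch)
    x = image.reshape(gh, patch_size, gw, patch_size, c)
    # Move the two grid axes to the front: (gh, gw, p, p, C)
    x = x.transpose(0, 2, 1, 3, 4)
    # One row per patch, ordered left-to-right, top-to-bottom
    return x.reshape(gh * gw, patch_size * patch_size * c)

image = np.random.rand(224, 224, 3).astype(np.float32)
tokens = patchify(image, 16)
print(tokens.shape)  # (196, 768): 14 x 14 patches, each 16 * 16 * 3 values
```

Each row of `tokens` is one "word" of the image sentence; everything else in a ViT operates on this sequence.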
Method 1: Use Keras Layers to Create Image Patches
This is my go-to method because it uses the built-in tf.image.extract_patches function, which cuts every patch in a single TensorFlow op and keeps this step fast during training.
I prefer this approach when I am building a standard ViT architecture because it integrates directly into the Keras functional API.
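Before wrapping patching in a layer, it is worth confirming how tf.image.extract_patches orders its output. This sketch assumes a small synthetic batch of 64x48 images built with tf.range, so every pixel value is unique, and checks that each output vector equals the matching image slice flattened in (row, column, channel) order, the same order a NumPy reshape uses.

```python
import numpy as np
import tensorflow as tf

patch_size = 16
# Synthetic batch of 2 RGB images, 64x48, where every pixel value is unique
images = tf.reshape(tf.range(2 * 64 * 48 * 3, dtype=tf.float32), [2, 64, 48, 3])

patches = tf.image.extract_patches(
    images=images,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
print(patches.shape)  # (2, 4, 3, 768): a 4 x 3 grid of 16 * 16 * 3 values per image

# Each patch vector should equal the matching image slice flattened (row, column, channel)
img = images.numpy()
p = patch_size
for i in range(patches.shape[1]):
    for j in range(patches.shape[2]):
        expected = img[1, i * p:(i + 1) * p, j * p:(j + 1) * p, :].reshape(-1)
        assert np.array_equal(patches[1, i, j].numpy(), expected)
print("extract_patches order matches row-major slicing")
```

Knowing this layout matters later, because reshaping tokens back into tiles for plotting (or matching them to convolution weights) only works if you assume the same order.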
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class PatchMaker(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # Dynamic batch size, so the layer works with any batch
        batch_size = tf.shape(images)[0]
        # Cut non-overlapping patches: (batch, H, W, C) -> (batch, H/p, W/p, p*p*C)
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the patch grid into a sequence: (batch, num_patches, p*p*C)
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

# Example usage with a 224x224 image (like a NYC real estate photo)
image_input = tf.random.normal([1, 224, 224, 3])
patch_layer = PatchMaker(patch_size=16)
output_tokens = patch_layer(image_input)
print(f"Tokenized Shape: {output_tokens.shape}")  # Tokenized Shape: (1, 196, 768)
```

You can see the output in the screenshot below.

Method 2: Implement Linear Projection of Patches in Keras
Once you have the patches, you need to project them into a vector space so the Transformer can learn the relationships between them.
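I always use a Dense layer right after the patching layer to map each flattened patch (768 raw pixel values for a 16x16 RGB patch) to a learned embedding of size projection_dim. I also add a learned position embedding, because self-attention on its own has no notion of where each patch came from in the image.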
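In plain NumPy terms, the encoder below computes `patches @ W + b + P`. This sketch uses the hypothetical names `W`, `b`, and `P` for the Dense kernel, the Dense bias, and the position-embedding table, filled with arbitrary random values rather than Keras's default initializers. It also counts the parameters those two layers add for 196 patches of 768 values projected to 64 dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, patch_dim, projection_dim = 2, 196, 768, 64

patches = rng.normal(size=(batch, num_patches, patch_dim)).astype(np.float32)
W = rng.normal(scale=0.02, size=(patch_dim, projection_dim)).astype(np.float32)   # Dense kernel
b = np.zeros(projection_dim, dtype=np.float32)                                     # Dense bias
P = rng.normal(scale=0.02, size=(num_patches, projection_dim)).astype(np.float32)  # position table

# The projection is applied to every patch independently; P has no batch axis,
# so broadcasting adds row k of P to token k of every image in the batch
encoded = patches @ W + b + P
print(encoded.shape)  # (2, 196, 64)

dense_params = patch_dim * projection_dim + projection_dim  # kernel + bias
position_params = num_patches * projection_dim              # one vector per position
print(dense_params, position_params)  # 49216 12544
```

The broadcast over the batch axis is the same thing the layer relies on when it adds `self.position_embedding(positions)` to the projected patches.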
```python
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Linear projection: p*p*C raw pixel values -> projection_dim features
        self.projection = layers.Dense(units=projection_dim)
        # One learned vector per patch position
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # Position indices 0 .. num_patches - 1
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # Add the position embeddings (broadcast over the batch) to the projected patches
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

# Dimensions for a 224x224 retail product photo split into 16x16 patches
projection_dim = 64
num_patches = (224 // 16) ** 2  # 14 * 14 = 196
encoder = PatchEncoder(num_patches, projection_dim)
final_embeddings = encoder(output_tokens)
print(f"Encoded Tokens Shape: {final_embeddings.shape}")  # Encoded Tokens Shape: (1, 196, 64)
```

You can see the output in the screenshot below.

Method 3: Use a Conv2D Layer for Efficient Keras Tokenization
A clever trick I’ve learned over the years is using a Conv2D layer with a stride equal to the kernel size to tokenize images.
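Because the stride equals the kernel size, each filter position covers exactly one non-overlapping patch, so this single layer does the work of patch extraction plus a Dense projection. It is often more efficient on modern GPUs, since convolutions run on heavily optimized backend kernels (cuDNN on NVIDIA GPUs).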
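You can check that equivalence numerically. This sketch assumes random inputs and one shared random weight tensor, and it uses tf.nn.conv2d instead of layers.Conv2D so the weights can be passed in directly. It compares the strided convolution with tf.image.extract_patches followed by a matrix product against the same weights reshaped to (768, 64).

```python
import numpy as np
import tensorflow as tf

patch_size, channels, projection_dim = 16, 3, 64
rng = np.random.default_rng(42)
images = tf.constant(rng.normal(size=(2, 224, 224, channels)).astype(np.float32))

# One shared weight tensor, viewed two ways
kernel = rng.normal(scale=0.02, size=(patch_size, patch_size, channels, projection_dim)).astype(np.float32)
dense_weights = kernel.reshape(patch_size * patch_size * channels, projection_dim)

# Path 1: strided convolution (stride == kernel size, no padding)
conv_out = tf.nn.conv2d(images, kernel, strides=patch_size, padding="VALID")  # (2, 14, 14, 64)
conv_tokens = tf.reshape(conv_out, [2, -1, projection_dim])                     # (2, 196, 64)

# Path 2: explicit patch extraction followed by a linear projection
patches = tf.image.extract_patches(
    images=images,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
flat = tf.reshape(patches, [2, -1, patch_size * patch_size * channels])  # (2, 196, 768)
patch_tokens = tf.tensordot(flat, dense_weights, axes=1)                  # (2, 196, 64)

max_diff = float(tf.reduce_max(tf.abs(conv_tokens - patch_tokens)))
print(conv_tokens.shape, max_diff)  # max_diff should be tiny (float32 rounding only)
```

The two paths agree because the Dense weights are the convolution kernel flattened in (row, column, channel) order, which is the same order extract_patches uses inside each patch.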
```python
from tensorflow import keras
from tensorflow.keras import layers

def conv_tokenization_method(inputs, patch_size, projection_dim):
    # This single layer performs both patching and linear projection
    tokens = layers.Conv2D(
        filters=projection_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
        name="conv_tokenization",
    )(inputs)
    # Flatten the spatial grid into a sequence with a Keras layer, so this also
    # works on symbolic tensors in the functional API
    # Example: (None, 14, 14, 64) -> (None, 196, 64)
    seq_len = tokens.shape[1] * tokens.shape[2]
    tokens = layers.Reshape((seq_len, projection_dim))(tokens)
    return tokens

# A minimal Keras model that only tokenizes; a classifier adds Transformer blocks on top
input_shape = (224, 224, 3)
inputs = layers.Input(shape=input_shape)
tokens = conv_tokenization_method(inputs, 16, 64)
model = keras.Model(inputs=inputs, outputs=tokens)
model.summary()  # conv_tokenization: (None, 14, 14, 64), 49,216 params; output: (None, 196, 64)
```

You can see the output in the screenshot below.

Method 4: Implement Shifted Patch Tokenization in Keras
If you are working with smaller datasets, like identifying specific tree species in Oregon forests, standard patches might lose detail at the edges.
I use shifted patch tokenization to let each token see pixels from the neighboring patches as well, which can noticeably improve accuracy for Keras ViT models trained on small datasets.
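The input construction is the heart of this idea. It follows the Shifted Patch Tokenization (SPT) approach from the paper "Vision Transformer for Small-Size Datasets": the image is stacked with four copies shifted diagonally by half a patch, and only then cut into patches. Here is a NumPy sketch of that step; `shift_image` and `spt_input` are hypothetical helper names, and zero-filling the exposed borders is an assumption.

```python
import numpy as np

def shift_image(image: np.ndarray, dy: int, dx: int) -> np.ndarray:
    """Translate an (H, W, C) image by (dy, dx) pixels, filling exposed borders with zeros."""
    h, w, _ = image.shape
    out = np.zeros_like(image)
    # out[i, j] = image[i - dy, j - dx] wherever the source pixel exists
    out[max(dy, 0):h + min(dy, 0), max(dx, 0):w + min(dx, 0)] = \
        image[max(-dy, 0):h + min(-dy, 0), max(-dx, 0):w + min(-dx, 0)]
    return out

def spt_input(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Stack the image with four diagonal half-patch shifts along the channel axis."""
    s = patch_size // 2
    shifts = [(-s, -s), (-s, s), (s, -s), (s, s)]
    return np.concatenate([image] + [shift_image(image, dy, dx) for dy, dx in shifts], axis=-1)

image = np.random.rand(224, 224, 3).astype(np.float32)
stacked = spt_input(image, patch_size=16)
print(stacked.shape)  # (224, 224, 15): original 3 channels + 4 shifted copies
```

With 16x16 patches, each token now carries 16 x 16 x 15 = 3,840 values instead of 768, and the projection layer maps that back down to projection_dim.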
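```python
class ShiftedPatchTokenization(layers.Layer):
    def __init__(self, patch_size, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.half = patch_size // 2
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.projection = layers.Dense(projection_dim)

    def shift(self, x, dy, dx):
        # Translate the batch by (dy, dx) pixels, zero-filling the exposed border
        h, w, s = x.shape[1], x.shape[2], self.half
        padded = tf.pad(x, [[0, 0], [s, s], [s, s], [0, 0]])
        return padded[:, s - dy : s - dy + h, s - dx : s - dx + w, :]

    def call(self, x):
        s = self.half
        # Concatenate the original with four diagonally shifted views: (B, H, W, 5*C)
        views = [x] + [self.shift(x, dy, dx) for dy, dx in ((-s, -s), (-s, s), (s, -s), (s, s))]
        x = tf.concat(views, axis=-1)
        # Standard non-overlapping patching on the stacked input
        patches = tf.image.extract_patches(
            images=x,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the grid into a sequence, normalize, then project
        patches = tf.reshape(patches, [tf.shape(x)[0], -1, patches.shape[-1]])
        return self.projection(self.layer_norm(patches))

# Initializing the shifted method for a Keras vision pipeline
shifted_layer = ShiftedPatchTokenization(patch_size=16, projection_dim=128)
spt_tokens = shifted_layer(tf.random.normal([1, 224, 224, 3]))
print(spt_tokens.shape)  # (1, 196, 128)
# This would be followed by your Transformer blocks
```

Method 5: Visualize Keras Tokens for Debugging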
I always tell my students that if you can’t see your tokens, you can’t be sure the patching step is feeding your model what you think it is.
This helper function allows you to visualize the exact patches that your Keras model is “seeing” before they are flattened into vectors.
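Alongside the plotting helper, a cheap numeric check is a round trip: rebuild the image from its tokens and confirm nothing moved. This NumPy sketch assumes tokens in row-major patch order, with each patch flattened as (row, column, channel), which is the layout tf.image.extract_patches produces; `unpatchify` is a hypothetical helper name. For the PatchMaker output you would pass `output_tokens[0].numpy()`.

```python
import numpy as np

def unpatchify(tokens: np.ndarray, grid_h: int, grid_w: int, patch_size: int, channels: int) -> np.ndarray:
    """Rebuild an (H, W, C) image from (grid_h*grid_w, patch_size*patch_size*C) row-major tokens."""
    x = tokens.reshape(grid_h, grid_w, patch_size, patch_size, channels)
    x = x.transpose(0, 2, 1, 3, 4)  # (grid_h, p, grid_w, p, C)
    return x.reshape(grid_h * patch_size, grid_w * patch_size, channels)

image = np.random.rand(64, 48, 3)
p = 16
# Build tokens by slicing patches left-to-right, top-to-bottom
tokens = np.stack([
    image[i:i + p, j:j + p].reshape(-1)
    for i in range(0, image.shape[0], p)
    for j in range(0, image.shape[1], p)
])
restored = unpatchify(tokens, image.shape[0] // p, image.shape[1] // p, p, 3)
print(tokens.shape, np.array_equal(restored, image))
```

If the round trip fails on your real tokens, the patch order or channel layout differs from what the rest of your pipeline assumes.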
```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

def visualize_keras_patches(image, patch_size):
    # extract_patches expects a 4D batch: (1, H, W, C)
    image_batch = tf.expand_dims(image, 0)
    patches = tf.image.extract_patches(
        images=image_batch,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    # Grid size: patches per column and per row (14 x 14 for 224 / 16)
    rows, cols = patches.shape[1], patches.shape[2]
    # Un-flatten each token back into a (patch_size, patch_size, 3) tile
    patches = tf.reshape(patches, [rows * cols, patch_size, patch_size, 3])
    plt.figure(figsize=(8, 8))
    # Plot every patch in its original grid position
    for i, patch in enumerate(patches):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(patch.numpy().astype("uint8"))
        plt.axis("off")
    plt.show()

# Testing with a sample array (simulating a US road sign image)
sample_img = tf.random.uniform([224, 224, 3], 0, 255)
visualize_keras_patches(sample_img, 16)
```

Handle Overlapping Patches in Keras
Sometimes a fixed grid isn’t enough, especially when I’m working on medical imaging or detailed logistics scanning.
You can set the strides argument of tf.image.extract_patches to a value smaller than patch_size to create overlapping tokens.
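The number of tokens depends on both the stride and the padding mode. This sketch (with a hypothetical helper name) applies TensorFlow's standard output-size rules per spatial axis: with "VALID", floor((size - patch_size) / stride) + 1; with "SAME", ceil(size / stride), because the borders are zero-padded.

```python
import math

def num_patches_per_side(size: int, patch_size: int, stride: int, padding: str) -> int:
    """Patches along one spatial axis, following TensorFlow's VALID/SAME output-size rules."""
    if padding == "VALID":
        return (size - patch_size) // stride + 1
    if padding == "SAME":
        return math.ceil(size / stride)
    raise ValueError(f"unknown padding: {padding}")

for stride, padding in [(16, "VALID"), (8, "VALID"), (8, "SAME")]:
    n = num_patches_per_side(224, 16, stride, padding)
    print(f"stride={stride:2d} {padding:5s}: {n} x {n} = {n * n} tokens")
```

For a 224x224 image, halving the stride from 16 to 8 takes you from 196 to 784 tokens with "SAME", and self-attention cost grows quadratically with sequence length, so overlap has a real compute price.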
```python
def overlapping_token_method(images, patch_size, stride_size):
    # A stride smaller than patch_size makes neighboring patches overlap;
    # "SAME" zero-pads the borders so the grid has ceil(H / stride) rows
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, stride_size, stride_size, 1],
        rates=[1, 1, 1, 1],
        padding="SAME",
    )
    # Returns the 4D grid; reshape to (batch, -1, p*p*C) before a PatchEncoder
    return patches

# Example: 16x16 patches with a stride of 8 (50% overlap)
overlap_tokens = overlapping_token_method(image_input, 16, 8)
print(f"Overlapping Patches Shape: {overlap_tokens.shape}")  # (1, 28, 28, 768)
```

In this tutorial, I showed you several different ways to handle image tokenization for Vision Transformers using Keras.
I’ve used these methods in everything from simple classification tasks to complex object detection pipelines.
If you’re just starting, I recommend using the Conv2D approach as it’s the most straightforward and usually results in the fastest training times.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.