Image Tokenization in Vision Transformers with Keras

In my four years of working with Keras, I’ve realized that moving from traditional CNNs to Vision Transformers (ViT) is a massive shift.

The most confusing part for many developers I mentor is how we actually turn a standard image into a sequence of tokens that a Transformer can understand.

In this tutorial, I will show you exactly how I handle tokenization in Keras. The examples use random tensors as stand-ins, but the same code applies to real-world images such as California housing satellite imagery or retail product photos.

What is Tokenization in Keras Vision Transformers?

Tokenization is the process of breaking an image into smaller, manageable “patches” that act like words in a sentence.

In a ViT, we don’t feed pixels to the model one by one; we group them into a grid of fixed-size square patches, so each token carries a small piece of local spatial context.
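To make the “patches as words” idea concrete, here is a minimal NumPy sketch of the same tokenization with no Keras involved. It assumes a channels-last image whose height and width are divisible by the patch size, and patchify is a hypothetical helper name, not a Keras API. A 224x224 RGB image cut into 16x16 patches gives (224 / 16)^2 = 196 tokens of 16 x 16 x 3 = 768 values each.

```python
import numpy as np

def patchify(image, patch_size):
    """Split an (H, W, C) image into a (num_patches, patch_size*patch_size*C) array."""
    height, width, channels = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError("Image height and width must be divisible by patch_size")
    grid_h, grid_w = height // patch_size, width // patch_size
    # (grid_h, p, grid_w, p, C): split each spatial axis into (grid index, offset inside patch)
    blocks = image.reshape(grid_h, patch_size, grid_w, patch_size, channels)
    # Move both grid axes to the front: (grid_h, grid_w, p, p, C)
    blocks = blocks.transpose(0, 2, 1, 3, 4)
    # One row per patch, read left to right, top to bottom
    return blocks.reshape(grid_h * grid_w, patch_size * patch_size * channels)

image = np.random.rand(224, 224, 3)
tokens = patchify(image, 16)
print(tokens.shape)  # (196, 768)
```

Because patches are read left to right, top to bottom, token 1 is the second patch of the first row and token 14 is the first patch of the second row.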

Method 1: Use Keras Layers to Create Image Patches

This is my go-to method because it uses the built-in tf.image.extract_patches function, which cuts a whole batch of images into patches with a single TensorFlow op.

I prefer this approach when I am building a standard ViT architecture because it integrates directly into the Keras functional API.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class PatchMaker(layers.Layer):
    """Splits a batch of images into a sequence of flattened patches."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # images: (batch, height, width, channels)
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # patches: (batch, rows, cols, patch_size * patch_size * channels)
        patch_dims = patches.shape[-1]
        # Flatten the patch grid into a sequence: (batch, num_patches, patch_dims)
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

# Example usage with a 224x224 RGB image (like a NYC real estate photo)
image_input = tf.random.normal([1, 224, 224, 3])
patch_layer = PatchMaker(patch_size=16)
output_tokens = patch_layer(image_input)
print(f"Tokenized Shape: {output_tokens.shape}")

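You can see the output in the screenshot below: the printed shape is (1, 196, 768), i.e. 196 patches from a 14 x 14 grid, each flattened to 16 x 16 x 3 = 768 pixel values.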

Keras Image Tokenization in Vision Transformers
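Before feeding patches to anything else, I like to confirm the layout tf.image.extract_patches produces. This short check (a sketch on random data) shows that each cell of the output grid is one patch flattened in row, column, channel order, which is why a plain tf.reshape to [batch, -1, patch_dims] in PatchMaker yields tokens in left-to-right, top-to-bottom order.

```python
import numpy as np
import tensorflow as tf

patch_size = 16
images = tf.random.normal([1, 224, 224, 3])

patches = tf.image.extract_patches(
    images=images,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
print(patches.shape)  # (1, 14, 14, 768)

# Slice patch (grid row 2, grid column 5) by hand and flatten it in (row, column, channel) order
manual = images[0, 2 * patch_size:3 * patch_size, 5 * patch_size:6 * patch_size, :]
print(np.allclose(patches[0, 2, 5].numpy(), tf.reshape(manual, [-1]).numpy()))  # True
```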

Method 2: Implement Linear Projection of Patches in Keras

Once you have the patches, you need to project them into a vector space so the Transformer can learn the relationships between them.

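I always use a Dense layer immediately after the patching layer to project each flattened patch (768 raw pixel values for a 16x16 RGB patch) into a learned embedding of size projection_dim. The same layer also adds a learned position embedding, because the projection on its own carries no information about where in the image each patch came from.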

class PatchEncoder(layers.Layer):
    """Projects flattened patches and adds a learned position embedding."""

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        # One integer position per token: 0, 1, ..., num_patches - 1
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # (batch, num_patches, dim) + (num_patches, dim) broadcasts over the batch
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded

# Dimensions matching the 224x224 input and 16x16 patches from Method 1
projection_dim = 64
num_patches = (224 // 16) ** 2  # 196
encoder = PatchEncoder(num_patches, projection_dim)
final_embeddings = encoder(output_tokens)
print(f"Encoded Tokens Shape: {final_embeddings.shape}")

You can see the output in the screenshot below: the encoded tokens have shape (1, 196, 64), one 64-dimensional embedding per patch.

Image Tokenization in Vision Transformers Keras
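Why bother with the position embedding? A Dense projection is applied to every token independently, so shuffling the patches only shuffles the projected tokens and the model cannot tell where a patch came from. This sketch rebuilds PatchEncoder’s two sublayers on random patches (the fixed shuffle perm is purely illustrative) and compares the results with and without positions.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

num_patches, patch_dims, projection_dim = 196, 768, 64
patches = tf.random.normal([1, num_patches, patch_dims])
perm = np.random.default_rng(0).permutation(num_patches)  # a fixed shuffle of token order
positions = tf.range(num_patches)

# The same two sublayers PatchEncoder uses
projection = layers.Dense(projection_dim)
position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

def encode(x, use_positions):
    out = projection(x)
    if use_positions:
        out = out + position_embedding(positions)
    return out.numpy()

shuffled = tf.gather(patches, perm, axis=1)

# Without positions, shuffling the input only shuffles the output tokens
no_pos = np.allclose(encode(patches, False)[:, perm], encode(shuffled, False), atol=1e-5)
# With positions, every slot gets its own learned offset, so the tokens really change
with_pos = np.allclose(encode(patches, True)[:, perm], encode(shuffled, True), atol=1e-5)
print(no_pos, with_pos)  # True False
```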

Method 3: Use a Conv2D Layer for Efficient Keras Tokenization

A clever trick I’ve learned over the years is using a Conv2D layer with a stride equal to the kernel size to tokenize images.

This method can be faster on modern GPUs because it fuses patching and projection into a single convolution, an operation that the TensorFlow backend runs through highly optimized kernels.

def conv_tokenization_method(inputs, patch_size, projection_dim):
    # One Conv2D performs both patching and linear projection:
    # stride == kernel size, so each patch_size x patch_size window is visited exactly once
    tokens = layers.Conv2D(
        filters=projection_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
        name="conv_tokenization"
    )(inputs)

    # Flatten the spatial grid into a sequence
    # Example: (None, 14, 14, 64) -> (None, 196, 64)
    # A Reshape layer (instead of tf.reshape) works on symbolic inputs in Keras 2 and Keras 3
    tokens = layers.Reshape((-1, projection_dim), name="flatten_tokens")(tokens)
    return tokens

# Building a quick Keras model that outputs the token sequence
input_shape = (224, 224, 3)
inputs = layers.Input(shape=input_shape)
tokens = conv_tokenization_method(inputs, 16, 64)
model = keras.Model(inputs=inputs, outputs=tokens)
model.summary()

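You can see the output in the screenshot below. The summary reports a final output shape of (None, 196, 64), and the Conv2D layer holds 16 x 16 x 3 x 64 + 64 = 49,216 trainable parameters, the same count as a Dense projection from a flattened 768-value patch to 64 dimensions.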

Image Tokenization in Vision Transformers with Keras
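Why does the Conv2D trick work? A convolution whose kernel size and stride both equal the patch size applies one shared weight matrix to each non-overlapping patch, which is exactly patch extraction followed by a Dense projection. This sketch (random inputs, position embeddings left out) copies the convolution’s kernel into a Dense layer and checks that both routes produce the same tokens.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

patch_size, projection_dim = 16, 64
images = tf.random.normal([2, 224, 224, 3])

# Route 1: strided Conv2D (kernel size == stride == patch size)
conv = layers.Conv2D(projection_dim, kernel_size=patch_size, strides=patch_size, padding="valid")
conv_tokens = tf.reshape(conv(images), [2, -1, projection_dim])

# Route 2: extract patches, then a Dense layer that reuses the conv weights
patches = tf.image.extract_patches(
    images=images,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
patches = tf.reshape(patches, [2, -1, patch_size * patch_size * 3])
dense = layers.Dense(projection_dim)
dense.build((None, patch_size * patch_size * 3))
kernel, bias = conv.get_weights()  # kernel: (16, 16, 3, 64)
# Flattening the kernel in (row, column, channel) order matches the patch layout
dense.set_weights([kernel.reshape(-1, projection_dim), bias])
dense_tokens = dense(patches)

print(conv_tokens.shape, dense_tokens.shape)  # (2, 196, 64) (2, 196, 64)
print(np.allclose(conv_tokens.numpy(), dense_tokens.numpy(), atol=1e-3))  # True
```

The equivalence is why I treat the Conv2D version as a drop-in replacement for Methods 1 and 2 combined: same parameters, one fused op.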

Method 4: Implement Shifted Patch Tokenization in Keras

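If you are working with smaller datasets, like identifying specific tree species in Oregon forests, standard non-overlapping patches can be a weak starting point: each token only sees its own pixels, so detail that straddles a patch boundary is split across tokens.

I use Shifted Patch Tokenization (SPT) to let each token see its neighborhood. The image is shifted diagonally by half a patch in four directions, the shifted copies are stacked with the original along the channel axis, and only then is the result cut into patches, normalized, and projected. This can noticeably improve accuracy when training Keras ViT models on small datasets.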

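class ShiftedPatchTokenization(layers.Layer):
    """Patches the image together with 4 diagonally shifted copies of itself."""

    def __init__(self, patch_size, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.shift = patch_size // 2  # shift by half a patch
        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
        self.projection = layers.Dense(projection_dim)

    def _shift(self, x, dy, dx):
        # dy/dx: +1 moves content down/right, -1 up/left; the vacated border is zero-filled
        s = self.shift
        rows = slice(None, -s) if dy > 0 else slice(s, None)
        cols = slice(None, -s) if dx > 0 else slice(s, None)
        pad_rows = [s, 0] if dy > 0 else [0, s]
        pad_cols = [s, 0] if dx > 0 else [0, s]
        return tf.pad(x[:, rows, cols, :], [[0, 0], pad_rows, pad_cols, [0, 0]])

    def call(self, x):
        # Original image + 4 diagonal shifts, stacked on channels: C -> 5C
        shifted = [self._shift(x, dy, dx) for dy, dx in [(-1, -1), (-1, 1), (1, -1), (1, 1)]]
        x = tf.concat([x] + shifted, axis=-1)

        # Standard non-overlapping patching on the stacked input
        patches = tf.image.extract_patches(
            images=x,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # (batch, rows, cols, dims) -> (batch, num_patches, dims)
        patches = tf.reshape(patches, [tf.shape(patches)[0], -1, patches.shape[-1]])
        return self.projection(self.layer_norm(patches))

# Initializing the shifted method for a Keras vision pipeline
shifted_layer = ShiftedPatchTokenization(patch_size=16, projection_dim=128)
spt_tokens = shifted_layer(tf.random.normal([1, 224, 224, 3]))
print(f"SPT Tokens Shape: {spt_tokens.shape}")  # (1, 196, 128)
# This would be followed by your Transformer blocks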
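To see what one shifted view actually looks like, here is a small NumPy illustration of the shift-and-stack step on a 4 x 4 single-channel image, shifting by 2 pixels (half of an assumed 4 x 4 patch). The function name shift_diagonal is hypothetical; the four diagonal directions follow the usual SPT setup.

```python
import numpy as np

def shift_diagonal(image, dy, dx, shift):
    """Shift an (H, W, C) image by `shift` pixels; dy/dx are +1 (down/right) or -1 (up/left).

    Pixels that move past the edge are dropped and the vacated border is zero-filled.
    """
    rows = slice(None, -shift) if dy > 0 else slice(shift, None)
    cols = slice(None, -shift) if dx > 0 else slice(shift, None)
    pad_rows = (shift, 0) if dy > 0 else (0, shift)
    pad_cols = (shift, 0) if dx > 0 else (0, shift)
    return np.pad(image[rows, cols, :], (pad_rows, pad_cols, (0, 0)))

image = np.arange(1, 17).reshape(4, 4, 1)
down_right = shift_diagonal(image, 1, 1, shift=2)
print(down_right[:, :, 0])
# [[0 0 0 0]
#  [0 0 0 0]
#  [0 0 1 2]
#  [0 0 5 6]]

# Original + four diagonal shifts stacked on the channel axis: C -> 5C
views = [image] + [shift_diagonal(image, dy, dx, 2) for dy, dx in [(-1, -1), (-1, 1), (1, -1), (1, 1)]]
stacked = np.concatenate(views, axis=-1)
print(stacked.shape)  # (4, 4, 5)
```

After stacking, a patch covering the bottom-right corner also carries pixels that originally sat up and to the left of it, which is how each token gains context from its neighbors.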

Method 5: Visualize Keras Tokens for Debugging

I always tell my students that if you can’t see your tokens, you don’t know if your model is learning the right features.

This helper function allows you to visualize the exact patches that your Keras model is “seeing” before they are flattened into vectors.

import matplotlib.pyplot as plt

def visualize_keras_patches(image, patch_size):
    # extract_patches expects a 4D batch, so add a batch axis
    image_batch = tf.expand_dims(image, 0)
    patches = tf.image.extract_patches(
        images=image_batch,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )

    # (1, rows, cols, p*p*3) -> (rows*cols, p, p, 3): one small RGB image per patch
    num_patches = patches.shape[1] * patches.shape[2]
    patches = tf.reshape(patches, [num_patches, patch_size, patch_size, 3])

    plt.figure(figsize=(8, 8))
    for i, patch in enumerate(patches[:16]):  # first 16 patches, in row-major order
        plt.subplot(4, 4, i + 1)
        # Assumes pixel values in [0, 255]; rescale first if your images are in [0, 1]
        plt.imshow(patch.numpy().astype("uint8"))
        plt.axis("off")
    plt.show()

# Testing with a sample array (simulating a US road sign image)
sample_img = tf.random.uniform([224, 224, 3], 0, 255)
visualize_keras_patches(sample_img, 16)
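Another debugging check I find useful is a round trip: put the patches back together and confirm you get the original image. This is a sketch assuming square images whose side is divisible by the patch size; unpatchify is a hypothetical helper, not a Keras function.

```python
import numpy as np
import tensorflow as tf

def unpatchify(patches, image_size, patch_size, channels):
    """Rebuild (batch, H, W, C) images from (batch, num_patches, patch_dims) tokens."""
    grid = image_size // patch_size
    # (batch, grid_row, grid_col, row_in_patch, col_in_patch, channel)
    x = tf.reshape(patches, [-1, grid, grid, patch_size, patch_size, channels])
    # Interleave grid and in-patch axes: (batch, grid_row, row_in_patch, grid_col, col_in_patch, channel)
    x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
    return tf.reshape(x, [-1, image_size, image_size, channels])

images = tf.random.uniform([2, 64, 64, 3])
patches = tf.image.extract_patches(
    images=images,
    sizes=[1, 16, 16, 1],
    strides=[1, 16, 16, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
tokens = tf.reshape(patches, [2, -1, 16 * 16 * 3])  # (2, 16, 768)
restored = unpatchify(tokens, image_size=64, patch_size=16, channels=3)
print(np.array_equal(images.numpy(), restored.numpy()))  # True
```

If this check fails in your own pipeline, the reshape order between patching and flattening is usually the culprit.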

Handle Overlapping Patches in Keras

Sometimes a fixed grid isn’t enough, especially when I’m working on medical imaging or detailed logistics scanning.

You can set the strides argument of tf.image.extract_patches smaller than patch_size to create overlapping tokens.

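def overlapping_token_method(images, patch_size, stride_size):
    # A stride smaller than patch_size makes neighboring patches overlap
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, stride_size, stride_size, 1],
        rates=[1, 1, 1, 1],
        padding="SAME",  # zero-pads the border so the grid covers the whole image
    )
    # Flatten the grid into a token sequence: (batch, num_patches, patch_dims)
    batch_size = tf.shape(patches)[0]
    return tf.reshape(patches, [batch_size, -1, patches.shape[-1]])

# Example: 16x16 patches with a stride of 8 (50% overlap)
overlap_tokens = overlapping_token_method(image_input, 16, 8)
print(f"Overlapping Patches Shape: {overlap_tokens.shape}")  # (1, 784, 768)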
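Overlap has a price: halving the stride roughly quadruples the number of tokens, and self-attention cost grows with the square of the sequence length. This small helper (a hypothetical name) applies the standard output-size formulas for VALID and SAME padding to count tokens for a 224x224 image with 16x16 patches.

```python
import math

def grid_size(image_size, patch_size, stride, padding):
    """Patches per side produced by tf.image.extract_patches for a square image."""
    if padding == "VALID":
        return (image_size - patch_size) // stride + 1
    if padding == "SAME":
        return math.ceil(image_size / stride)
    raise ValueError(f"unknown padding: {padding}")

for stride, padding in [(16, "VALID"), (8, "VALID"), (8, "SAME")]:
    side = grid_size(224, 16, stride, padding)
    print(f"stride={stride:2d} {padding:5s} -> {side * side} tokens")
# stride=16 VALID -> 196 tokens
# stride= 8 VALID -> 729 tokens
# stride= 8 SAME  -> 784 tokens
```

784 tokens is 4x the 196 of the non-overlapping grid, so each attention layer compares about 16x as many token pairs.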

In this tutorial, I showed you several different ways to handle image tokenization for Vision Transformers using Keras.

I’ve used these methods in everything from simple classification tasks to complex object detection pipelines.

If you’re just starting, I recommend using the Conv2D approach as it’s the most straightforward and usually results in the fastest training times.
