Build a Mobile-Friendly Transformer-Based Model for Image Classification in Keras

I was working on a deep learning project where I needed to deploy an image classification model on a mobile device. The challenge was clear: I needed something lightweight yet powerful enough to handle complex image data efficiently.

After exploring several architectures, I came across MobileViT, a mobile-friendly Transformer-based model that combines the strengths of Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs). This hybrid design aims to pair the accuracy of Transformers with the efficiency of CNNs, which makes it a strong fit for mobile and edge devices.

In this Python tutorial, I’ll show you how to build a mobile-friendly Transformer-based image classification model in Keras from scratch. I’ll walk you through the code, explain each step in simple terms, and share some practical tips based on my experience.

What is a Mobile-Friendly Transformer Model?

A mobile-friendly Transformer is designed to bring the power of attention mechanisms (used in Transformers) to devices with limited resources like smartphones.

Traditional Vision Transformers are accurate but computationally heavy. MobileViT addresses this by combining convolutional blocks (for local feature extraction) with Transformer blocks (for global context).

In simple terms, CNNs handle the details, while Transformers handle the big picture. The result is a balanced model that performs well on both accuracy and speed, ideal for real-world mobile image classification tasks.

Set Up the Python Environment

Before we start coding, make sure you have the following Python libraries installed. The keras-cv package is only needed for Method 2, which loads a pretrained model.

pip install tensorflow keras numpy matplotlib keras-cv

I always recommend creating a virtual environment for your Python projects to keep dependencies clean and organized.

Import Required Libraries

We’ll begin by importing the required Python libraries.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

These imports cover everything we need, from building the model to visualizing training results.

Method 1 – Build a Mobile-Friendly Transformer (MobileViT) from Scratch

This is the main part of our tutorial. We’ll build a MobileViT-like model using Keras layers. I’ll break the model into smaller functions to make it easier to understand and modify.

Step 1: Define the Convolutional Block

This block helps the model learn local image features efficiently.

def conv_block(x, filters, kernel_size=3, strides=1):
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x

I like to keep my convolutional blocks simple: Conv2D, BatchNorm, and a ReLU activation. This combination works well for feature extraction.
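To see what the block does to a tensor, here is a small sketch that feeds a batch of random 32×32 RGB images through conv_block with strides=2. The demo model and random batch are purely illustrative. With padding='same', a stride of 2 halves each spatial dimension, and the filters argument sets the number of output channels.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def conv_block(x, filters, kernel_size=3, strides=1):
    # Same block as above: Conv2D -> BatchNorm -> ReLU
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x


# A stride-2 block on a 32x32 RGB input: 'same' padding with stride 2
# halves height and width, and `filters` sets the channel count
inputs = keras.Input(shape=(32, 32, 3))
outputs = conv_block(inputs, filters=32, strides=2)
demo = keras.Model(inputs, outputs)

# Random images stand in for real data here
batch = np.random.rand(4, 32, 32, 3).astype('float32')
features = demo(batch)
print(features.shape)  # (4, 16, 16, 32)
```

Because ReLU is the last operation, every value in the output feature map is non-negative.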

Step 2: Define the Transformer Encoder Block

Transformers capture long-range dependencies, which CNNs often miss. Here, we’ll define a compact Transformer encoder.

def transformer_encoder(x, projection_dim, num_heads):
    # Layer normalization
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)

    # Multi-head self-attention
    attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(x_norm, x_norm)

    # Skip connection
    x = layers.Add()([x, attention_output])

    # Feed-forward network
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    ffn = keras.Sequential([
        layers.Dense(projection_dim * 2, activation='relu'),
        layers.Dense(projection_dim),
    ])
    x = layers.Add()([x, ffn(x_norm)])
    return x

This block uses multi-head self-attention to learn relationships between different parts of an image.
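The encoder uses pre-normalization and two residual Add layers, so it returns a sequence with exactly the same shape it receives. A consequence worth checking: because the feed-forward network ends in Dense(projection_dim) and its output is added back to the input, the token size (the last dimension of x) must equal projection_dim. The sketch below assumes 64 tokens of size 64 (for example, a flattened 8×8 feature map with 64 channels); the shapes are illustrative.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def transformer_encoder(x, projection_dim, num_heads):
    # Same encoder as above: pre-norm attention and feed-forward, each with a residual add
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim)(x_norm, x_norm)
    x = layers.Add()([x, attention_output])
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    ffn = keras.Sequential([
        layers.Dense(projection_dim * 2, activation='relu'),
        layers.Dense(projection_dim),
    ])
    x = layers.Add()([x, ffn(x_norm)])
    return x


# 64 tokens, each a 64-dimensional vector; the token size must equal
# projection_dim so the residual Adds line up
tokens = keras.Input(shape=(64, 64))
encoded = transformer_encoder(tokens, projection_dim=64, num_heads=4)
encoder = keras.Model(tokens, encoded)

sample = np.random.rand(2, 64, 64).astype('float32')
out = encoder(sample)
print(out.shape)  # (2, 64, 64)
```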

Step 3: Combine CNN and Transformer Blocks

Now, let’s combine both approaches into a MobileViT block.

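def mobilevit_block(x, filters, transformer_dim, num_heads):
    # Local representation: 3x3 conv, then a 1x1 conv projecting to transformer_dim
    local_features = conv_block(x, filters)
    local_features = conv_block(local_features, transformer_dim, kernel_size=1)

    # Flatten the spatial grid into a sequence of h * w tokens
    _, h, w, c = local_features.shape
    x_flattened = layers.Reshape((h * w, c))(local_features)

    # Global representation: self-attention across all spatial positions
    x_transformed = transformer_encoder(x_flattened, transformer_dim, num_heads)

    # Reshape back to an image-like format and project back to `filters` channels
    x_reshaped = layers.Reshape((h, w, c))(x_transformed)
    x_reshaped = conv_block(x_reshaped, filters, kernel_size=1)

    # Combine local and global features: concatenate the block input with the
    # global features, then fuse them with a 3x3 conv
    x_out = layers.Concatenate()([x, x_reshaped])
    x_out = conv_block(x_out, filters)

    return x_out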

This function merges CNN and Transformer logic, the essence of a mobile-friendly Transformer.
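Because this block treats every spatial position as a token, self-attention compares every token with every other one, so each head builds an attention matrix with (h·w)² entries. The sketch below estimates that size for the two input resolutions used in this tutorial, assuming square inputs and the two stride-2 conv blocks that precede the first MobileViT block. The attention_cost helper is hypothetical. The patch_area parameter reflects how the original MobileViT design cuts this cost: it unfolds the feature map into small non-overlapping patches and runs attention once per pixel position inside a patch, over tokens / patch_area tokens each time. The simplified block above corresponds to patch_area=1.

```python
def attention_cost(input_size, num_stride2_blocks=2, patch_area=1):
    """Rough size of the per-head attention matrix in the first MobileViT block.

    Hypothetical helper. Assumes square inputs and 'same'-padded stride-2 convs,
    which halve the spatial size (rounding up). patch_area=1 matches the
    simplified block in this tutorial, where every pixel is a token.
    """
    size = input_size
    for _ in range(num_stride2_blocks):
        size = -(-size // 2)  # ceiling division, as 'same' padding does for stride 2
    tokens = size * size
    # Attention runs patch_area times, each over tokens // patch_area tokens
    per_group = tokens // patch_area
    return size, tokens, patch_area * per_group * per_group


for n in (32, 128):
    size, tokens, entries = attention_cost(n)
    print(f"{n}x{n} input -> {size}x{size} map, {tokens} tokens, {entries:,} attention entries")
# 32x32 input -> 8x8 map, 64 tokens, 4,096 attention entries
# 128x128 input -> 32x32 map, 1024 tokens, 1,048,576 attention entries
```

Quadrupling the input side length multiplies the attention size by 256, which is why smaller inputs keep inference fast. Calling attention_cost(128, patch_area=4) gives 262,144 entries, a 4× reduction.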

Step 4: Build the Complete Model

Now, we’ll stack everything together into a full image classification model.

def build_mobilevit(input_shape=(128, 128, 3), num_classes=10):
    inputs = keras.Input(shape=input_shape)

    # Initial convolution
    x = conv_block(inputs, 32, 3, 2)
    x = conv_block(x, 64, 3, 2)

    # MobileViT blocks
    x = mobilevit_block(x, 64, transformer_dim=64, num_heads=4)
    x = mobilevit_block(x, 96, transformer_dim=96, num_heads=4)

    # Global pooling and classification
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs, name='MobileViT_Keras')
    return model

This is our final MobileViT-style model: compact, efficient, and ready to be converted for mobile deployment.

Step 5: Compile and Train the Model

Let’s compile and train our model using a small dataset like CIFAR-10.

# Load dataset and scale pixel values to [0, 1] as float32
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Build model
model = build_mobilevit(input_shape=(32, 32, 3), num_classes=10)

# Compile
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train
history = model.fit(x_train, y_train, validation_split=0.1, epochs=10, batch_size=64)

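Training this small model on a GPU should take only a few minutes. Watch both the training and validation accuracy as the epochs progress; if validation accuracy stalls while training accuracy keeps climbing, the model is starting to overfit.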

Step 6: Evaluate and Visualize the Results

Let’s check how well our model performs.

test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {test_acc:.2f}")

# Plot training history
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')
plt.show()

You can refer to the screenshot below to see the output.

[Screenshot: Mobile-Friendly Transformer-Based Model for Image Classification in Keras]

This gives you a clear picture of how the model is performing over time.
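A single accuracy number can hide classes the model struggles with. The sketch below computes accuracy per class with NumPy, assuming you pass the integer labels (CIFAR-10 returns them with shape (N, 1)) and the class probabilities from model.predict(x_test). The per_class_accuracy helper is hypothetical, and the toy arrays stand in for real predictions.

```python
import numpy as np


def per_class_accuracy(y_true, probs, num_classes):
    """Fraction of correctly classified samples for each class (hypothetical helper).

    y_true: integer labels, shape (N,) or (N, 1)
    probs:  predicted class probabilities, shape (N, num_classes)
    """
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.argmax(probs, axis=1)
    accuracies = np.full(num_classes, np.nan)  # NaN marks classes with no samples
    for c in range(num_classes):
        mask = y_true == c
        if mask.any():
            accuracies[c] = np.mean(y_pred[mask] == c)
    return accuracies


# Toy example with 3 classes: class 0 is right once out of twice,
# class 1 is always right, and class 2 has no samples
labels = np.array([[0], [0], [1], [1]])
probs = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.6, 0.1],
])
acc = per_class_accuracy(labels, probs, num_classes=3)
print(acc)
```

With the trained model, the call would be per_class_accuracy(y_test, model.predict(x_test), num_classes=10).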

Method 2 – Use a Pretrained MobileViT Model in Keras

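If you don’t want to build from scratch, you can use KerasCV or TensorFlow Hub to load a pretrained MobileViT model. Class and preset names have changed between KerasCV releases, so check the documentation for your installed version; the snippet below assumes a MobileViT class with a mobilevit_xx_small preset.

Here’s how you can do it: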

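from keras_cv.models import MobileViT

# Load a MobileViT model from a KerasCV preset
# (class and preset names depend on your KerasCV version)
model = MobileViT.from_preset("mobilevit_xx_small", input_shape=(128, 128, 3), num_classes=10)

# CIFAR-10 images are 32x32, so resize them on the fly to the model's 128x128 input
train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .map(lambda image, label: (tf.image.resize(image, (128, 128)), label))
            .batch(64))

# Compile and train
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(train_ds, epochs=5)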

Starting from pretrained weights usually needs fewer epochs and less data than training from scratch, which makes this route a good fit when you need to ship quickly.

Optimize the Model for Mobile Deployment

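Once trained, you can convert the model to TensorFlow Lite for mobile use. Setting the converter’s default optimization applies post-training quantization, which shrinks the file further.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Optional: post-training (dynamic range) quantization to reduce model size
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open('mobilevit_model.tflite', 'wb') as f:
    f.write(tflite_model)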

This converts your Keras model into a lightweight .tflite format suitable for Android or iOS apps.
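Before shipping the .tflite file, it is worth checking that it produces the same predictions as the Keras model. The sketch below uses tf.lite.Interpreter to run one image through a converted model and compares the result with Keras. To keep it self-contained, it assumes a tiny stand-in classifier in place of your trained MobileViT and converts without quantization, so the two outputs should agree closely. A quantized model will show larger, though usually small, differences.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Tiny stand-in classifier so the sketch runs on its own; use your trained model in practice
model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    layers.Conv2D(8, 3, padding='same', activation='relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dense(10, activation='softmax'),
])

# Float conversion (no quantization) so outputs should match Keras closely
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()

# Load the model from memory; pass model_path='mobilevit_model.tflite' to load a saved file
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Run one random image through TFLite and through Keras, then compare
image = np.random.rand(1, 32, 32, 3).astype(np.float32)
interpreter.set_tensor(input_details['index'], image)
interpreter.invoke()
tflite_probs = interpreter.get_tensor(output_details['index'])
keras_probs = model.predict(image, verbose=0)

print(np.max(np.abs(tflite_probs - keras_probs)))  # should be very small
```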

Practical Tips from My Experience

  • Resize your input images to a smaller dimension (like 128×128) to keep inference fast.
  • Use mixed precision training to speed up training on GPUs.
  • Fine-tune pretrained models instead of training from scratch; this often gives better results with less data.
  • For mobile deployment, test your model on real devices to measure latency.
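For the mixed precision tip, Keras lets you switch to float16 computation with a global policy while keeping weights in float32. The sketch below builds a small illustrative model (not the MobileViT model above) under the mixed_float16 policy and keeps the final softmax in float32, which the Keras mixed precision guidance recommends for numerically stable outputs. It resets the policy at the end so later models are built in float32. The speedup only materializes on GPUs with float16 support; on a CPU the code still runs, just without the benefit.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Compute in float16 while keeping variables in float32
keras.mixed_precision.set_global_policy('mixed_float16')

inputs = keras.Input(shape=(32, 32, 3))
conv = layers.Conv2D(32, 3, padding='same', activation='relu')
x = conv(inputs)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10)(x)
# Keep the final softmax in float32 for numerically stable probabilities
outputs = layers.Activation('softmax', dtype='float32')(x)
model = keras.Model(inputs, outputs)

print(conv.compute_dtype)   # float16
print(conv.variable_dtype)  # float32

# Restore the default policy so later models are built in float32
keras.mixed_precision.set_global_policy('float32')
```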

Building a mobile-friendly Transformer-based model for image classification in Keras is easier than it sounds. With a balanced mix of CNNs and Transformers, you can achieve high accuracy even on resource-constrained devices.

I’ve personally used this approach in several real-world Python projects, and it consistently delivers great results, both in performance and efficiency.

If you’re planning to deploy deep learning models on mobile apps, MobileViT or similar architectures should definitely be part of your toolkit.
