I was working on a deep learning project where I needed to deploy an image classification model on a mobile device. The challenge was clear: I needed something lightweight yet powerful enough to handle complex image data efficiently.
After exploring several architectures, I came across MobileViT, a mobile-friendly Transformer-based model that combines the strengths of Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs). This hybrid model offers the accuracy of Transformers with the efficiency of CNNs, perfect for mobile and edge devices.
In this Python tutorial, I’ll show you how to build a mobile-friendly Transformer-based image classification model in Keras from scratch. I’ll walk you through the code, explain each step in simple terms, and share some practical tips based on my experience.
What is a Mobile-Friendly Transformer Model?
A mobile-friendly Transformer is designed to bring the power of attention mechanisms (used in Transformers) to devices with limited resources like smartphones.
Traditional Vision Transformers are accurate but computationally heavy. MobileViT solves this by combining convolutional blocks (for local feature extraction) with Transformer blocks (for global context understanding).
In simple terms, CNNs handle the details, while Transformers handle the big picture. The result is a balanced model that performs well on both accuracy and speed, ideal for real-world mobile image classification tasks.
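To make the cost argument concrete, here is a rough back-of-the-envelope sketch of my own (it is not part of MobileViT itself). It assumes the cost of self-attention is dominated by the two token-by-token matrix products, and the helper name `attention_mult_adds` is purely illustrative. It compares attention over every pixel of a 128×128 image with attention over the 32×32 feature map you get after two stride-2 convolutions.

```python
def attention_mult_adds(height, width, dim):
    """Rough multiply-add count for one self-attention layer over a feature map.

    Assumption: cost is dominated by Q @ K^T and weights @ V, each of which
    takes about N * N * dim multiply-adds for N = height * width tokens.
    """
    tokens = height * width
    return 2 * tokens * tokens * dim


# Attention on every pixel of a 128x128 image vs. on a 32x32 feature map
# produced by two stride-2 convolutions (both with 64 channels).
full_res = attention_mult_adds(128, 128, 64)
downsampled = attention_mult_adds(32, 32, 64)

print(f"128x128 tokens: {full_res:,} multiply-adds")
print(f"32x32 tokens:   {downsampled:,} multiply-adds")
print(f"Reduction:      {full_res // downsampled}x")  # 256x
```

Because the cost grows with the square of the token count, letting convolutions shrink the feature map first is what keeps the Transformer part affordable on a phone.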
Set Up the Python Environment
Before we start coding, make sure you have the following Python libraries installed:
```bash
pip install tensorflow keras numpy matplotlib
```
I always recommend creating a virtual environment for your Python projects to keep dependencies clean and organized.
Import Required Libraries
We’ll begin by importing the required Python libraries.
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
```
These imports cover everything we need, from building the model to visualizing training results.
Method 1 – Build a Mobile-Friendly Transformer (MobileViT) from Scratch
This is the main part of our tutorial. We’ll build a MobileViT-like model using Keras layers. I’ll break the model into smaller functions to make it easier to understand and modify.
Step 1: Define the Convolutional Block
This block helps the model learn local image features efficiently.
```python
def conv_block(x, filters, kernel_size=3, strides=1):
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x
```
I like to keep my convolutional blocks simple: Conv2D, BatchNorm, and ReLU activation. This combination works great for feature extraction.
Step 2: Define the Transformer Encoder Block
Transformers capture long-range dependencies, which CNNs often miss. Here, we’ll define a compact Transformer encoder.
```python
def transformer_encoder(x, projection_dim, num_heads):
    # Layer normalization
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    # Multi-head self-attention
    attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(x_norm, x_norm)
    # Skip connection
    x = layers.Add()([x, attention_output])
    # Feed-forward network
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    ffn = keras.Sequential([
        layers.Dense(projection_dim * 2, activation='relu'),
        layers.Dense(projection_dim),
    ])
    x = layers.Add()([x, ffn(x_norm)])
    return x
```
This block uses multi-head self-attention to learn relationships between different parts of an image.
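If you are curious what the attention step computes, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. It is a simplification for illustration: the Keras `MultiHeadAttention` layer also splits the computation across several heads and applies an output projection, which I leave out here. The function name and the random weight matrices are my own placeholders.

```python
import numpy as np


def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a (num_tokens, dim) array."""
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])          # (num_tokens, num_tokens)
    scores -= scores.max(axis=-1, keepdims=True)     # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over each row
    return weights @ v, weights


rng = np.random.default_rng(0)
num_tokens, dim = 16, 8                    # e.g. a 4x4 feature map with 8 channels
tokens = rng.normal(size=(num_tokens, dim))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))

output, weights = self_attention(tokens, w_q, w_k, w_v)
print(output.shape)    # (16, 8)
print(weights.shape)   # (16, 16)
```

Each row of `weights` says how much one token looks at every other token, which is how a patch in one corner of the image can influence a patch in the opposite corner.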
Step 3: Combine CNN and Transformer Blocks
Now, let’s combine both approaches into a MobileViT block.
```python
def mobilevit_block(x, filters, transformer_dim, num_heads):
    # Note: transformer_dim must equal filters, otherwise the skip connections
    # in the encoder and the reshape back to (h, w, c) will not line up.
    # Local representation
    local_features = conv_block(x, filters)
    # Flatten spatial dimensions
    b, h, w, c = local_features.shape
    x_flattened = layers.Reshape((h * w, c))(local_features)
    # Transformer representation
    x_transformed = transformer_encoder(x_flattened, transformer_dim, num_heads)
    # Reshape back to image-like format
    x_reshaped = layers.Reshape((h, w, c))(x_transformed)
    # Combine local and global features
    x_out = conv_block(x_reshaped, filters)
    return x_out
```
This function merges CNN and Transformer logic, which is the essence of a mobile-friendly Transformer.
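The flatten-and-reshape step is easy to get wrong, so here is a small NumPy sketch of what it does to a single feature map (without the batch dimension). The array sizes are arbitrary values I picked for illustration.

```python
import numpy as np

h, w, c = 4, 4, 3
feature_map = np.arange(h * w * c).reshape(h, w, c)

# Flatten the spatial grid into a sequence of h*w tokens, each with c channels.
tokens = feature_map.reshape(h * w, c)

# Token i corresponds to pixel (i // w, i % w) in row-major order.
assert np.array_equal(tokens[5], feature_map[1, 1])

# Reshaping back restores the original layout exactly.
restored = tokens.reshape(h, w, c)
print(np.array_equal(restored, feature_map))  # True
```

The round trip only works if the Transformer leaves the channel count unchanged, which is why the model in this tutorial always passes the same value for `filters` and `transformer_dim`.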
Step 4: Build the Complete Model
Now, we’ll stack everything together into a full image classification model.
```python
def build_mobilevit(input_shape=(128, 128, 3), num_classes=10):
    inputs = keras.Input(shape=input_shape)
    # Initial convolution
    x = conv_block(inputs, 32, 3, 2)
    x = conv_block(x, 64, 3, 2)
    # MobileViT blocks
    x = mobilevit_block(x, 64, transformer_dim=64, num_heads=4)
    x = mobilevit_block(x, 96, transformer_dim=96, num_heads=4)
    # Global pooling and classification
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name='MobileViT_Keras')
    return model
```
This is our final MobileViT-style model: compact, efficient, and ready for mobile deployment.
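Before training, I find it useful to trace how the spatial size shrinks through the network, since that decides how many tokens the Transformer blocks have to process. The sketch below is plain Python with helper names of my own; it assumes only the two stride-2 stem convolutions change the spatial size, which matches the model above because the MobileViT blocks use stride 1.

```python
import math


def same_conv_output(size, stride):
    """Spatial output size of a padding='same' convolution."""
    return math.ceil(size / stride)


def trace_shapes(input_size):
    """Follow one spatial dimension through the two stride-2 stem convolutions."""
    size = input_size
    for stride in (2, 2):
        size = same_conv_output(size, stride)
    return size, size * size  # feature-map side, number of Transformer tokens


for input_size in (128, 32):
    side, tokens = trace_shapes(input_size)
    print(f"{input_size}x{input_size} input -> {side}x{side} feature map -> {tokens} tokens")
```

With the default 128×128 input the Transformer sees 1024 tokens, while a 32×32 CIFAR-10 image gives only 64, so the small dataset trains much faster.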
Step 5: Compile and Train the Model
Let’s compile and train our model using a small dataset like CIFAR-10.
```python
# Load dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build model
model = build_mobilevit(input_shape=(32, 32, 3), num_classes=10)

# Compile
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train
history = model.fit(x_train, y_train, validation_split=0.1, epochs=10, batch_size=64)
```
Training this small model on a GPU typically takes just a few minutes, and you should see the accuracy improve steadily over the epochs.
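CIFAR-10 labels come as integer class indices with shape `(N, 1)` rather than one-hot vectors, which is why the model is compiled with `sparse_categorical_crossentropy`. Here is a small NumPy sketch of what that loss measures. It is my own simplified version for illustration (Keras additionally clips probabilities to avoid taking the log of zero), and the probabilities are made-up numbers.

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs):
    """Mean negative log-probability assigned to the true class of each sample."""
    labels = np.asarray(labels).reshape(-1)             # accepts shape (N,) or (N, 1)
    true_class_probs = probs[np.arange(len(labels)), labels]
    return float(-np.log(true_class_probs).mean())


# Two samples, three classes; each row is a softmax output.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([[0], [2]])                           # integer labels, like CIFAR-10's y_train

loss = sparse_categorical_crossentropy(labels, probs)
print(round(loss, 4))
```

The more probability the model puts on the correct class, the closer this value gets to zero.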
Step 6: Evaluate and Visualize the Results
Let’s check how well our model performs.
```python
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {test_acc:.2f}")

# Plot training history
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')
plt.show()
```
You can refer to the screenshot below to see the output.

This gives you a clear picture of how the model is performing over time.
Method 2 – Use a Pretrained MobileViT Model in Keras
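If you don’t want to build from scratch, you can load a pretrained MobileViT model from a library such as KerasCV or TensorFlow Hub. The class and preset names below depend on the library version you have installed, so check them against its documentation before running the code.
Here’s how you can do it:
```python
from keras_cv.models import MobileViT  # availability depends on your keras_cv version

# Load pretrained MobileViT model (verify the preset name for your version)
backbone = MobileViT.from_preset("mobilevit_xx_small", input_shape=(128, 128, 3), num_classes=10)

# CIFAR-10 images are 32x32, so upsample them to the 128x128 input the model expects
inputs = keras.Input(shape=(32, 32, 3))
outputs = backbone(layers.Resizing(128, 128)(inputs))
model = keras.Model(inputs, outputs)

# Compile and train
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, validation_split=0.1, epochs=5, batch_size=64)
```
This method is faster and ideal for production environments where you need quick deployment.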
Optimize the Model for Mobile Deployment
Once trained, you can convert the model for mobile use using TensorFlow Lite.
```python
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model
with open('mobilevit_model.tflite', 'wb') as f:
    f.write(tflite_model)
```
This converts your Keras model into a lightweight .tflite format suitable for Android or iOS apps.
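If you want a smaller file, you can also ask the converter to apply its default optimizations, and it is worth running one input through the TFLite interpreter to confirm the converted model behaves as expected. The sketch below uses a tiny stand-in classifier (`demo_model`, a placeholder of mine) so that it runs without training; swap in your trained model in practice. Whether every layer of a Transformer-based model converts cleanly can depend on your TensorFlow version, so treat this as a starting point.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Tiny stand-in classifier (hypothetical) so the sketch runs without training.
demo_model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    layers.Conv2D(8, 3, strides=2, padding="same", activation="relu"),
    layers.GlobalAveragePooling2D(),
    layers.Dense(10, activation="softmax"),
])

# Convert with the default optimizations (dynamic-range quantization of weights).
converter = tf.lite.TFLiteConverter.from_keras_model(demo_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_bytes = converter.convert()

# Run one image through the TFLite interpreter to sanity-check the converted model.
interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
interpreter.allocate_tensors()
input_info = interpreter.get_input_details()[0]
output_info = interpreter.get_output_details()[0]

image = np.random.rand(1, 32, 32, 3).astype(np.float32)
interpreter.set_tensor(input_info["index"], image)
interpreter.invoke()
probs = interpreter.get_tensor(output_info["index"])
print(probs.shape)  # (1, 10)
```

After quantizing, compare the accuracy of the converted model with the original Keras model on a few test images before shipping it.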
Practical Tips from My Experience
- Always resize your input images to a smaller dimension (like 128×128) to keep inference fast.
- Use mixed precision training in Python to speed up training on GPUs.
- Fine-tune pretrained models instead of training from scratch for better results.
- For mobile deployment, test your model on real devices to check latency.
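For the mixed precision tip, here is a minimal sketch of how I would switch it on in Keras. The small model is only a placeholder to show where the policy applies; the one detail to remember is keeping the final softmax in float32 for numerical stability.

```python
from tensorflow import keras
from tensorflow.keras import layers

# Compute in float16 while keeping variables in float32.
keras.mixed_precision.set_global_policy("mixed_float16")

inputs = keras.Input(shape=(32, 32, 3))
x = layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10)(x)
# Keep the final softmax in float32 for numerical stability.
outputs = layers.Activation("softmax", dtype="float32")(x)
model = keras.Model(inputs, outputs)

print(model.layers[1].compute_dtype)   # float16
print(model.layers[-1].compute_dtype)  # float32

# Restore the default policy so later code is unaffected.
keras.mixed_precision.set_global_policy("float32")
```

Set the policy before you build the model, because layers pick up the policy that is active when they are created. The speedup depends on your GPU, so measure it on your own hardware.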
Building a mobile-friendly Transformer-based model for image classification in Keras is easier than it sounds. With a balanced mix of CNNs and Transformers, you can achieve high accuracy even on resource-constrained devices.
I’ve personally used this approach in several real-world Python projects, and it consistently delivers great results, both in performance and efficiency.
If you’re planning to deploy deep learning models on mobile apps, MobileViT or similar architectures should definitely be part of your toolkit.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.