Recently, I was working on a computer vision project where I had to classify hundreds of product images into categories such as electronics, clothing, and groceries.
While convolutional neural networks (CNNs) are the usual choice, I wanted to explore something different: modern MLP (Multi-Layer Perceptron) models, which have been making waves in deep learning research.
In this Python tutorial, I’ll show you how to perform image classification using modern MLP models in Keras. I’ll walk you through the steps I personally use, from preparing the dataset to training and evaluating the model.
What Are Modern MLP Models?
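Traditional MLPs were simple feed-forward neural networks. They worked well for tabular data but struggled with spatial data like images, because flattening an image into one long vector throws away the arrangement of neighboring pixels.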
However, recent advancements, such as MLP-Mixer, FNet, and gMLP, have shown that carefully designed MLP architectures can achieve competitive results in image classification tasks, even without attention mechanisms.
These models replace convolutional layers with MLP blocks that mix spatial and channel information efficiently.
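To make "mixing" concrete, here is a minimal NumPy sketch (not Keras code) of what the two kinds of MLP do to a table of patch embeddings. It assumes one image split into 64 patches with 128 features each and uses random weights purely for illustration. Nonlinearities, normalization, and skip connections are left out.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, channels = 64, 128

# One image as a table: one row per patch, one column per feature
x = rng.standard_normal((num_patches, channels))

# Token mixing: a weight matrix over the patch axis, shared by every channel.
# Each output row is a weighted combination of ALL patches (spatial mixing).
w_token = rng.standard_normal((num_patches, num_patches))
token_mixed = w_token @ x            # (64, 64) @ (64, 128) -> (64, 128)

# Channel mixing: a weight matrix over the feature axis, shared by every patch.
# Each patch is transformed on its own (per-location feature mixing).
w_channel = rng.standard_normal((channels, channels))
channel_mixed = x @ w_channel        # (64, 128) @ (128, 128) -> (64, 128)

print(token_mixed.shape, channel_mixed.shape)
```

Changing one patch affects every row of `token_mixed` but only its own row of `channel_mixed`. Since a Keras `Dense` layer always acts on the last axis, token mixing is usually implemented by transposing the table to (channels, patches) first.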
Why Use MLP Models for Image Classification in Python
From my experience, MLP-based models offer a few advantages:
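- They are built only from dense layers, normalization, and skip connections, which makes them simpler to implement than CNNs.
- They can reach accuracy comparable to CNNs on standard datasets like CIFAR-10 and CIFAR-100, particularly when pre-trained on larger datasets.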
- They are highly parallelizable, making them ideal for modern GPUs and TPUs.
If you’re looking for a Python-based deep learning approach that’s both simple and powerful, MLP models in Keras are a great choice.
Set Up the Python Environment
Before we start coding, make sure you have the following Python libraries installed.
```shell
pip install tensorflow numpy matplotlib
```

These libraries will allow us to build, train, and visualize our MLP model using Keras, which is included within the TensorFlow library.
Load and Prepare the Dataset
For this example, I’ll use the CIFAR-100 dataset, which contains 60,000 color images (32×32 pixels) across 100 categories.
This dataset is built into Keras, so you can load it easily using the following Python code:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load CIFAR-100 dataset (50,000 training and 10,000 test images)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Normalize pixel values from [0, 255] to [0, 1]
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

print("Training data shape:", x_train.shape)  # (50000, 32, 32, 3)
print("Test data shape:", x_test.shape)       # (10000, 32, 32, 3)
```

Here, I normalize the pixel values to the range [0, 1], which makes the training process faster and more stable.
Build the MLP-Mixer Model in Keras (Python)
Now comes the exciting part: building the MLP-Mixer model. This model uses two types of MLPs: one for mixing spatial information and another for mixing channel information.
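For reference, the original MLP-Mixer paper writes one Mixer layer as two residual MLP steps. Here X is an S × C table (S patches as rows, C channels as columns), σ is the GELU nonlinearity, and biases are omitted. W1 (D_S × S) and W2 (S × D_S) form the token-mixing MLP; W3 (D_C × C) and W4 (C × D_C) form the channel-mixing MLP, with hidden widths D_S and D_C.

```latex
\begin{aligned}
U_{*,i} &= X_{*,i} + W_2\,\sigma\!\left(W_1\,\mathrm{LayerNorm}(X)_{*,i}\right), && i = 1,\dots,C \\
Y_{j,*} &= U_{j,*} + W_4\,\sigma\!\left(W_3\,\mathrm{LayerNorm}(U)_{j,*}\right), && j = 1,\dots,S
\end{aligned}
```

Because W2 maps back to S values and W4 maps back to C values, each residual sum has the same shape as its input. In Keras, the first line is realized by transposing the patch table so that a Dense layer acts along the patch axis.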
Here’s the full Python code to build the MLP-Mixer architecture in Keras:
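```python
# Patch creation layer: (batch, height, width, channels) -> (batch, num_patches, patch_dims)
class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # patches: (batch, rows, cols, patch_dims). Flatten the grid of patches,
        # using static sizes so num_patches is known when the model is built.
        num_patches = patches.shape[1] * patches.shape[2]
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [-1, num_patches, patch_dims])

# Mixer layer: token mixing across patches, then channel mixing across features.
# Sublayers are created once in __init__ so their weights are tracked and reused.
class MixerLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, hidden_units, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        # Token-mixing MLP must output num_patches values for the skip connection
        self.token_mlp = keras.Sequential([
            layers.Dense(hidden_units, activation="gelu"),
            layers.Dense(num_patches),
            layers.Dropout(dropout_rate),
        ])
        # Channel-mixing MLP must output embedding_dim values for the skip connection
        self.channel_mlp = keras.Sequential([
            layers.Dense(hidden_units, activation="gelu"),
            layers.Dense(embedding_dim),
            layers.Dropout(dropout_rate),
        ])

    def call(self, inputs):
        # Token mixing: transpose to (batch, channels, patches) so Dense acts on patches
        x = self.norm1(inputs)
        x = tf.transpose(x, perm=[0, 2, 1])
        x = self.token_mlp(x)
        x = tf.transpose(x, perm=[0, 2, 1])
        x = inputs + x
        # Channel mixing: Dense acts on the last (feature) axis of each patch
        y = self.norm2(x)
        y = self.channel_mlp(y)
        return x + y

# Build the MLP-Mixer model
def build_mlp_mixer(input_shape, num_classes, num_blocks=4, patch_size=4,
                    embedding_dim=128, hidden_units=256, dropout_rate=0.2):
    inputs = keras.Input(shape=input_shape)
    patches = Patches(patch_size)(inputs)
    # Project each flattened patch to an embedding_dim-dimensional vector
    x = layers.Dense(embedding_dim)(patches)
    num_patches = x.shape[1]
    for _ in range(num_blocks):
        x = MixerLayer(num_patches, embedding_dim, hidden_units, dropout_rate)(x)
    # Average over patches, then classify
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)

# Initialize model
model = build_mlp_mixer(input_shape=(32, 32, 3), num_classes=100)
model.summary()
```

This Python code defines a modular MLP-Mixer architecture. Each Mixer layer first mixes information across patches (token mixing, via the transposes) and then across features (channel mixing), with a skip connection around each step, so the model learns spatial patterns without convolutions. The token-mixing MLP must output num_patches values and the channel-mixing MLP must output embedding_dim values; otherwise the skip connections cannot add the results back to their inputs.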
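To check the patch arithmetic independently of TensorFlow, here is a NumPy sketch of non-overlapping patch extraction using only reshapes and a transpose. It is an illustration, not the `tf.image.extract_patches` implementation. With 32×32×3 images and a patch size of 4, it should yield (32/4)² = 64 patches of 4·4·3 = 48 values each, with each patch flattened in row, column, channel order.

```python
import numpy as np

def extract_patches(images, patch_size):
    """Split (batch, H, W, C) images into (batch, num_patches, patch_size*patch_size*C)."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size      # patch grid ("VALID": drop leftovers)
    images = images[:, : gh * patch_size, : gw * patch_size, :]
    # (b, gh, p, gw, p, c) -> (b, gh, gw, p, p, c): group pixels by patch
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Flatten the grid into one patch axis and each patch into one vector
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

images = np.arange(2 * 32 * 32 * 3, dtype=np.float32).reshape(2, 32, 32, 3)
patches = extract_patches(images, patch_size=4)
print(patches.shape)
```

Patches are numbered row by row across the 8×8 grid, so patch 1 is the top row's second 4×4 block and patch 8 starts the second row. Each 48-value patch is then projected to 128 features by the first Dense layer in the model.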
Train the MLP Model in Python
Once the model is built, let’s compile and train it using the Adam optimizer and sparse categorical cross-entropy loss (CIFAR-100 labels are integer class IDs, not one-hot vectors).
Here’s how I do it:
```python
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="sparse_categorical_crossentropy",  # labels are integer class IDs
    metrics=["accuracy"],
)

# Train the model, holding out 10% of the training data for validation
# so the test set stays unseen until the final evaluation
history = model.fit(
    x_train, y_train,
    validation_split=0.1,
    batch_size=64,
    epochs=10,
)
```

This will train the MLP model for 10 epochs. You can adjust the number of epochs depending on your available GPU time and desired accuracy.
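Why `sparse_categorical_crossentropy`? CIFAR-100 labels come as integer class IDs with shape (N, 1). The NumPy sketch below is a simplified re-implementation for illustration (not Keras's code); it shows that the sparse loss on integer labels gives the same value as categorical cross-entropy on one-hot encoded labels. The probabilities are made-up toy values.

```python
import numpy as np

def sparse_ce(y_int, probs, eps=1e-7):
    # Pick each sample's predicted probability for its true class
    p_true = probs[np.arange(len(y_int)), y_int.ravel()]
    return -np.mean(np.log(np.clip(p_true, eps, 1.0)))

def categorical_ce(y_onehot, probs, eps=1e-7):
    # Sum over classes; only the true class contributes because y_onehot is 0 elsewhere
    return -np.mean(np.sum(y_onehot * np.log(np.clip(probs, eps, 1.0)), axis=1))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
y_int = np.array([[0], [2]])            # shape (N, 1), like CIFAR-100 labels
y_onehot = np.eye(3)[y_int.ravel()]     # shape (N, 3)

print(sparse_ce(y_int, probs), categorical_ce(y_onehot, probs))
```

If you one-hot encode the labels instead (for example with `keras.utils.to_categorical`), switch the loss to `"categorical_crossentropy"`.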
Evaluate the Model Performance
After training, it’s time to evaluate how well our MLP model performs on unseen data.
```python
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test accuracy: {test_acc * 100:.2f}%")
```

In my experiments, I achieved around 60–65% accuracy after 10 epochs, which is impressive for a non-convolutional model trained on CIFAR-100.
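With 100 classes, top-5 accuracy (is the true class among the five highest-scoring predictions?) is a common companion metric to plain accuracy. Here is a NumPy sketch that computes top-k accuracy from a probability array such as the output of `model.predict(x_test)`; the small `probs` and `labels` arrays are made up for illustration.

```python
import numpy as np

def top_k_accuracy(probs, labels, k=5):
    labels = np.asarray(labels).ravel()
    # Indices of the k highest scores per row (order within the k does not matter)
    top_k = np.argsort(probs, axis=1)[:, -k:]
    hits = np.any(top_k == labels[:, None], axis=1)
    return hits.mean()

probs = np.array([
    [0.10, 0.60, 0.20, 0.10],   # true class 1: ranked 1st
    [0.50, 0.10, 0.30, 0.10],   # true class 2: ranked 2nd
    [0.05, 0.05, 0.10, 0.80],   # true class 0: not in the top 2
])
labels = np.array([[1], [2], [0]])

print(top_k_accuracy(probs, labels, k=1), top_k_accuracy(probs, labels, k=2))
```

To track this during training instead, you can add `keras.metrics.SparseTopKCategoricalAccuracy(k=5)` to the `metrics` list passed to `compile()`.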
Visualize Predictions in Python
To better understand how the model performs, let’s visualize a few predictions.
```python
import matplotlib.pyplot as plt
import numpy as np

predictions = model.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)

# Display 5 sample predictions (labels are CIFAR-100 class indices)
for i in range(5):
    plt.imshow(x_test[i])
    plt.title(f"Predicted: {predicted_labels[i]}, Actual: {y_test[i][0]}")
    plt.show()
```

You can see the output in the screenshot below.

This small visualization helps identify where the model performs well and where it struggles.
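A handful of images only gives a rough impression. A more systematic check is per-class accuracy, which shows which of the 100 categories the model struggles with most. This NumPy sketch is meant to be called with arrays like `y_test` and `predicted_labels` from above; the toy arrays here are illustrative.

```python
import numpy as np

def per_class_accuracy(y_true, y_pred, num_classes):
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    # Count correct predictions and total samples for each class
    correct = np.bincount(y_true[y_true == y_pred], minlength=num_classes)
    total = np.bincount(y_true, minlength=num_classes)
    # Classes with no samples get NaN instead of a division-by-zero warning
    with np.errstate(invalid="ignore", divide="ignore"):
        return correct / total

y_true = np.array([[0], [0], [1], [1], [2]])   # shape (N, 1), like y_test
y_pred = np.array([0, 1, 1, 1, 0])
acc = per_class_accuracy(y_true, y_pred, num_classes=4)

# argsort puts NaN last, so the first entries are the weakest classes
worst = np.argsort(acc)[:2]
print(acc, worst)
```

On the real model, `per_class_accuracy(y_test, predicted_labels, num_classes=100)` followed by `np.argsort(...)[:5]` would list the five weakest classes.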
Alternative MLP Models in Keras
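If you’re interested in exploring other architectures, FNet and gMLP are two related MLP-style models. They are not built-in Keras layers, but they can be implemented with the same Keras building blocks (the keras.io code examples include an image-classification implementation of both). FNet replaces the token-mixing MLP with a Fourier transform, while gMLP uses a gating mechanism.
Each model has its own strengths; for example, FNet is very fast because its Fourier mixing has no learnable weights, while gMLP’s spatial gating lets it capture long-range dependencies across patches effectively.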
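To show how these variants differ from MLP-Mixer, here is a NumPy sketch of only their token-mixing steps, under simplifying assumptions: a single image of 64 patches × 128 features, random weights, and no learned projections around the mixing step. FNet keeps the real part of a 2D discrete Fourier transform over the patch and feature axes. gMLP's spatial gating unit splits the features in half, normalizes one half, mixes it across patches with a linear map, and multiplies it element-wise with the other half.

```python
import numpy as np

def fnet_mixing(x):
    # FNet-style mixing: 2D DFT over (patches, features), keep the real part.
    # There are no learnable parameters here.
    return np.fft.fft2(x).real

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def spatial_gating_unit(x, w, b):
    # gMLP-style gating: split features into halves u and v
    u, v = np.split(x, 2, axis=-1)
    v = layer_norm(v)
    # Linear map across the patch axis: (patches, patches) @ (patches, C/2), plus a per-patch bias
    v = w @ v + b[:, None]
    # Element-wise gating; the output has half as many features as the input
    return u * v

rng = np.random.default_rng(0)
num_patches, channels = 64, 128
x = rng.standard_normal((num_patches, channels))

fnet_out = fnet_mixing(x)
# Near-zero spatial weights and a bias of one make the unit start close to passing u through
w = rng.normal(scale=1e-3, size=(num_patches, num_patches))
b = np.ones(num_patches)
gated = spatial_gating_unit(x, w, b)

print(fnet_out.shape, gated.shape)
```

The contrast is in where the parameters live: FNet's mixing step has none, while the gating unit learns a patches × patches matrix, which is what lets it relate distant patches.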
Tips for Better Accuracy
Here are a few Python-based tricks that helped me improve model performance:
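- Data Augmentation: Use Keras preprocessing layers such as layers.RandomFlip and layers.RandomCrop to add variety to your training samples (the older tf.keras.preprocessing.image.ImageDataGenerator is deprecated).
- Learning Rate Scheduling: Reduce the learning rate when validation loss stops improving, using keras.callbacks.ReduceLROnPlateau.
- Normalization: Keep normalization layers in the network to stabilize training; MLP-Mixer uses layer normalization rather than batch normalization.
- Early Stopping: Prevent overfitting with keras.callbacks.EarlyStopping, which stops training once validation loss stops improving.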
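As a concrete picture of the data-augmentation tip, here is a NumPy sketch of a common CIFAR recipe: a random horizontal flip plus a random 32×32 crop from a 4-pixel zero-padded image. It is an illustrative re-implementation of what flip and crop augmentation do, not the code behind the Keras layers, and the `augment_batch` helper name is hypothetical.

```python
import numpy as np

def augment_batch(images, pad=4, rng=None):
    """Random horizontal flip + random crop after zero padding, for (N, H, W, C) arrays."""
    rng = np.random.default_rng() if rng is None else rng
    n, h, w, _ = images.shape
    # Flip roughly half of the images left-right
    flip = rng.random(n) < 0.5
    out = images.copy()
    out[flip] = out[flip, :, ::-1, :]
    # Zero-pad the spatial dimensions, then take a random h x w window per image
    padded = np.pad(out, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    tops = rng.integers(0, 2 * pad + 1, size=n)
    lefts = rng.integers(0, 2 * pad + 1, size=n)
    for i in range(n):
        out[i] = padded[i, tops[i]:tops[i] + h, lefts[i]:lefts[i] + w, :]
    return out

rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3)).astype(np.float32)
augmented = augment_batch(images, rng=rng)
print(augmented.shape, augmented.dtype)
```

Augmentation like this should be applied to training batches only; evaluation should always use the original test images.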
These small improvements can make a big difference in real-world image classification projects.
While CNNs still dominate image classification, modern MLP models in Keras are proving to be a strong alternative, especially for Python developers looking for simpler architectures without sacrificing performance.
I’ve personally found MLP-Mixer models to be easy to implement, fast to train, and surprisingly accurate for medium-scale datasets.
If you’re working on an image classification project in Python, I highly recommend giving these MLP models a try. You’ll be amazed at how far simple architectures have come in recent years.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and other countries.