Recently, while mentoring a group of data science students, I was asked how to quickly build a Convolutional Neural Network (CNN) in Python Keras to classify handwritten digits using the MNIST dataset.
I realized that although the MNIST dataset is one of the most common starting points for deep learning, many beginners still find it confusing to set up a CNN properly in Keras.
So, in this tutorial, I’ll share my firsthand experience of building a simple yet powerful MNIST ConvNet in Python Keras. I’ll walk you through each step, from loading data to training and evaluating the model.
What is the MNIST Dataset?
The MNIST dataset is a collection of 70,000 grayscale images of handwritten digits (0–9). Each image is 28×28 pixels, making it small and easy to process.
It’s widely used for testing and benchmarking deep learning models because it’s simple yet effective for understanding how neural networks classify visual patterns.
In this tutorial, we’ll use Python Keras to train a CNN that can recognize these digits with about 99% test accuracy.
Why Use Keras for CNNs in Python
Keras is one of the most beginner-friendly deep learning frameworks. This tutorial uses the version that ships with TensorFlow (tf.keras), so you get a simple API on top of TensorFlow’s flexibility.
I’ve been working with Keras for over four years, and I can confidently say that its clean syntax and modular design make it perfect for experimenting with CNNs.
With just a few lines of Python code, you can define, train, and evaluate a convolutional neural network.
Step 1 – Import Required Python Libraries
Before we start, we need to import all the required Python libraries. Here’s the first step of our Python Keras ConvNet project.
# Importing libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np
We’re using TensorFlow’s Keras API, which is the recommended way to build deep learning models in modern Python environments. Matplotlib is used for visualization, and NumPy helps with numerical operations.
Step 2 – Load and Explore the MNIST Dataset in Python
Keras provides the MNIST dataset built in, so you don’t need to download anything manually.
Let’s load the dataset and explore its structure.
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# Display dataset shapes
print("Training data shape:", train_images.shape)
print("Testing data shape:", test_images.shape)
You’ll see that the training set has 60,000 images, and the test set has 10,000 images. Each image is a 28×28 grayscale image representing a digit from 0 to 9.
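Beyond the shapes, it is often worth checking how many examples each digit has. Here is a minimal sketch using NumPy only. The sample_labels array is a hypothetical stand-in for train_labels, so the counts below are illustrative and not the real MNIST distribution.

```python
import numpy as np

# Hypothetical stand-in for train_labels (the real array has 60,000 entries)
sample_labels = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=np.uint8)

# Count how often each digit 0-9 appears; minlength keeps absent digits at 0
counts = np.bincount(sample_labels, minlength=10)

for digit, count in enumerate(counts):
    print(f"Digit {digit}: {count} samples")
```

To run the same check on the real data, pass train_labels to np.bincount before the labels are one-hot encoded.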
Step 3 – Preprocess the Data for CNN in Python
Neural networks perform better when input data is normalized. We’ll reshape and scale our pixel values between 0 and 1.
# Reshape to (samples, height, width, channels) and scale pixels to [0, 1]
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
# Convert integer labels to one-hot vectors
train_labels = keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
Reshaping ensures that the data has a channel dimension (1 for grayscale), which Conv2D layers require. Using -1 lets NumPy infer the number of samples instead of hard-coding it. Normalization usually makes training faster and more stable.
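To see what this preprocessing does to the arrays, here is a small NumPy-only sketch. The fake_images and fake_labels arrays are hypothetical stand-ins for the MNIST data, and the identity-matrix trick is just one way to produce the same one-hot layout that to_categorical returns.

```python
import numpy as np

# Hypothetical stand-in for a few MNIST images: uint8 pixels in [0, 255]
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 0, 9, 3])

# Add the channel axis and scale to [0, 1], as in the preprocessing step
x = fake_images.reshape((-1, 28, 28, 1)).astype("float32") / 255

# One-hot encode by picking rows of a 10x10 identity matrix
one_hot = np.eye(10, dtype="float32")[fake_labels]

print(x.shape, x.dtype)
print(one_hot[0])
```

Each row of one_hot has a single 1 at the index of its digit, which is the format categorical crossentropy expects.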
Step 4 – Build a Simple MNIST ConvNet in Python Keras
Now comes the exciting part: building the CNN model. We’ll create a simple architecture with two convolutional layers, each followed by max pooling, and then a dense layer with dropout and a softmax output layer.
# Build the CNN model
model = models.Sequential([
    # Declare the input shape explicitly: 28x28 pixels, 1 grayscale channel
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])
# Display model summary
model.summary()
This CNN uses ReLU activation for non-linearity and softmax for classification. The dropout layer helps reduce overfitting by randomly disabling half of the dense layer’s outputs during training.
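If you want to check the numbers that model.summary() prints, the output sizes and parameter counts can be worked out by hand. The sketch below is plain Python and assumes the Conv2D defaults (stride 1, 'valid' padding) and 2×2 pooling. The helper names are my own and not part of Keras.

```python
def conv_out(size, kernel):
    # 'valid' padding with stride 1: the kernel must fit inside the image
    return size - kernel + 1

def conv_params(kernel, in_channels, filters):
    # kernel x kernel x in_channels weights per filter, plus one bias each
    return kernel * kernel * in_channels * filters + filters

def dense_params(n_in, n_out):
    # One weight per input-output pair, plus one bias per output
    return n_in * n_out + n_out

size = 28
size = conv_out(size, 3)  # after the first Conv2D
size = size // 2          # after the first 2x2 max pooling
size = conv_out(size, 3)  # after the second Conv2D
size = size // 2          # after the second 2x2 max pooling
flat = size * size * 64   # length of the flattened vector

total = (conv_params(3, 1, 32) + conv_params(3, 32, 64)
         + dense_params(flat, 64) + dense_params(64, 10))

print("Flattened size:", flat)
print("Total parameters:", total)
```

Most of the parameters sit in the first dense layer, which is why dropout is placed right after it.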
Step 5 – Compile the CNN Model in Python
Before training, we need to compile the model. We’ll use the Adam optimizer, which is efficient and works well with most deep learning tasks.
# Compile the model
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
Here, we’re using categorical crossentropy because our labels are one-hot encoded. Accuracy is used as the evaluation metric.
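To make the loss less of a black box, here is a simplified NumPy sketch of softmax followed by categorical crossentropy. It uses three classes and hypothetical logits to keep the numbers readable. The Keras implementation handles more edge cases, so treat this as an illustration of the formula rather than a copy of the library code.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum first for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0), then keep only the log-probability of the true class
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

# Hypothetical raw scores for two samples over three classes
logits = np.array([[2.0, 0.5, 0.1],
                   [0.1, 0.2, 3.0]])
# One-hot labels: sample 0 is class 0, sample 1 is class 1
y_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])

probs = softmax(logits)
losses = categorical_crossentropy(y_true, probs)
print("Probabilities:", probs.round(3))
print("Per-sample loss:", losses.round(3))
```

The second sample puts most of its probability on the wrong class, so its loss is much larger than the first one.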
Step 6 – Train the MNIST ConvNet Model in Python
Now, let’s train the model using our training data. We’ll run it for 5 epochs, which is enough to achieve high accuracy on MNIST.
# Train the model
history = model.fit(train_images, train_labels,
                    epochs=5,
                    batch_size=64,
                    validation_split=0.1)
During training, Keras will display the loss and accuracy for both training and validation sets. You’ll notice that even after a few epochs, the model achieves over 98% accuracy.
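The step counter in the progress bar follows from the settings above. Here is a small arithmetic sketch in plain Python. It assumes the documented behavior that validation_split holds out the last fraction of the samples and that the final partial batch counts as one step.

```python
import math

n_total = 60000          # MNIST training images
validation_split = 0.1
batch_size = 64
epochs = 5

# Samples kept for training; the rest are held out for validation
n_train = math.floor(n_total * (1 - validation_split))
n_val = n_total - n_train

# The last, smaller batch still counts as a step
steps_per_epoch = math.ceil(n_train / batch_size)
total_updates = steps_per_epoch * epochs

print(f"Training samples: {n_train}, validation samples: {n_val}")
print(f"Steps per epoch: {steps_per_epoch}, weight updates overall: {total_updates}")
```

Because the held-out samples come from the end of the array, it is worth shuffling data that is sorted by class before using validation_split. MNIST as loaded by Keras is already mixed.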
Step 7 – Evaluate the CNN Model on Test Data
Once the model is trained, we can evaluate it on the test dataset to see how well it generalizes.
# Evaluate on test data
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test Accuracy: {test_acc * 100:.2f}%")
You should get a test accuracy of around 99%, which is excellent for such a simple CNN. This shows how powerful even a basic Keras ConvNet can be in Python.
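The accuracy that evaluate reports is simply the share of samples whose highest-probability class matches the label. Here is a NumPy sketch with hypothetical probabilities for four samples and three classes. The same lines would work on the output of model.predict and the one-hot test labels.

```python
import numpy as np

# Hypothetical softmax outputs for four samples over three classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
# One-hot labels for classes 0, 1, 2, 1
one_hot_labels = np.eye(3)[[0, 1, 2, 1]]

predicted = probs.argmax(axis=1)          # most likely class per sample
actual = one_hot_labels.argmax(axis=1)    # undo the one-hot encoding

accuracy = float(np.mean(predicted == actual))
wrong = np.flatnonzero(predicted != actual)  # indices worth inspecting

print(f"Accuracy: {accuracy:.2%}")
print("Misclassified indices:", wrong.tolist())
```

Looking at the misclassified indices is often more informative than the single accuracy figure, since it shows which digits the model confuses.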
Step 8 – Visualize Training Performance in Python
It’s always helpful to visualize how your model performed during training. We can plot the accuracy and loss curves to analyze the model’s learning behavior.
# Plot training history
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Model Accuracy')
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.title('Model Loss')
plt.show()
These plots help identify if the model is overfitting or underfitting. In our case, the training and validation curves should align closely, indicating a well-trained model.
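The same check can be done in code instead of by eye. The sketch below uses a hypothetical history dictionary with made-up loss values (not results from this model) in the same layout as history.history, and flags the common overfitting pattern where validation loss rises while training loss keeps falling.

```python
# Hypothetical values for illustration only, shaped like history.history
history_dict = {
    "loss": [0.30, 0.10, 0.07, 0.05, 0.04],
    "val_loss": [0.08, 0.06, 0.05, 0.055, 0.06],
}

train_loss = history_dict["loss"]
val_loss = history_dict["val_loss"]

# Index of the epoch with the lowest validation loss (0-based)
best_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i])

# Validation loss got worse after its best point...
rising_after_best = val_loss[-1] > val_loss[best_epoch]
# ...while training loss kept improving
still_improving_on_train = train_loss[-1] < train_loss[best_epoch]
possible_overfitting = rising_after_best and still_improving_on_train

print(f"Best epoch: {best_epoch + 1}")
print("Possible overfitting:", possible_overfitting)
```

With real runs, small fluctuations in validation loss are normal, so a single rising epoch is a hint to look closer and not proof of overfitting.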
Step 9 – Make Predictions Using the Trained CNN in Python
After training, you can use the model to predict new digits.
Let’s visualize a few predictions.
# Make predictions
predictions = model.predict(test_images)
# Display sample predictions
for i in range(5):
    plt.imshow(test_images[i].reshape(28, 28), cmap='gray')
    plt.title(f"Predicted: {np.argmax(predictions[i])}")
    plt.axis('off')
    plt.show()
This step is useful for testing your model with real-world handwritten digits. You can even extend it to recognize custom digits drawn using a touchscreen or scanned image.
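If you try your own digits, the image has to match what the model saw during training: a 28×28 array scaled to [0, 1], with a light stroke on a dark background, and a batch dimension in front. Here is a NumPy sketch of that preparation. The prepare_digit helper and the synthetic scan are hypothetical, and resizing or centering a real photo is left out.

```python
import numpy as np

def prepare_digit(image_uint8, invert=False):
    """Turn one 28x28 uint8 image into a (1, 28, 28, 1) float32 batch."""
    if image_uint8.shape != (28, 28):
        raise ValueError("expected a 28x28 grayscale image")
    img = image_uint8.astype("float32") / 255
    if invert:
        # MNIST digits are light strokes on a dark background,
        # so dark ink on white paper needs to be flipped
        img = 1.0 - img
    return img.reshape(1, 28, 28, 1)

# Synthetic "scan": white page with a dark vertical stroke
scan = np.full((28, 28), 255, dtype=np.uint8)
scan[6:22, 13:15] = 0

batch = prepare_digit(scan, invert=True)
print(batch.shape, batch.dtype)
```

The resulting batch can be passed to model.predict, and np.argmax on the first row of the result gives the predicted digit.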
Step 10 – Save and Load the Model in Python Keras
Finally, let’s save our trained model so we can reuse it later without retraining.
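# Save the model in the native Keras format
model.save('mnist_cnn_model.keras')
# Load the model
loaded_model = keras.models.load_model('mnist_cnn_model.keras')
The .keras extension selects the native Keras format. The older .h5 (HDF5) format still works but is considered legacy in recent Keras releases.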
Saving and loading models is a best practice in Python Keras projects, especially when deploying them in production. This makes it easy to reuse your trained model for inference or further training.
Alternative Method – Use Functional API in Python Keras
If you prefer more flexibility, you can build the same CNN using the Keras Functional API.
Here’s a quick example.
inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, (3, 3), activation='relu')(inputs)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation='softmax')(x)
functional_model = keras.Model(inputs, outputs)
functional_model.compile(optimizer='adam',
                         loss='categorical_crossentropy',
                         metrics=['accuracy'])
functional_model.summary()
The Functional API is great when you need to build complex architectures such as multi-input or multi-output models.
For beginners, though, the Sequential API (used earlier) is simpler and easier to follow.
Conclusion
Building a simple MNIST ConvNet in Python Keras is one of the best ways to start learning deep learning. With just a few lines of Python code, we created a model that recognizes handwritten digits with about 99% accuracy.
If you’re new to CNNs, I recommend experimenting with different hyperparameters, such as the number of filters, dropout rate, or optimizer, to see how they affect performance. Once you’re comfortable, you can move on to more advanced datasets like CIFAR-10 or Fashion MNIST.
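One simple way to organize such experiments is to list every combination up front. The sketch below uses only the standard library. The search_space values are hypothetical starting points, and the training step for each configuration is left out.

```python
from itertools import product

# Hypothetical values to try; adjust to your own experiments
search_space = {
    "filters": [(32, 64), (16, 32)],   # filters in the two Conv2D layers
    "dropout": [0.3, 0.5],
    "optimizer": ["adam", "rmsprop"],
}

# Build one dictionary per combination of settings
configs = [
    dict(zip(search_space, values))
    for values in product(*search_space.values())
]

print(f"{len(configs)} configurations to train")
for config in configs[:3]:
    print(config)
```

Each dictionary can then be used to build and train one model, keeping the validation accuracy alongside it for comparison.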

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.