While working on a project that required classifying handwritten digits, I found the MNIST dataset to be perfect for this task. Using PyTorch made implementing neural networks surprisingly simple.
In this article, I’ll walk you through creating, training, and testing a neural network on the MNIST dataset using PyTorch. We’ll start with the basics and gradually build up to a working model.
Whether you are new to deep learning or want to refresh your PyTorch skills, this tutorial provides hands-on experience with a fundamental image classification task.
So let’s get started!
Step 1 – Set Up Your Environment
Before we start coding, we need to set up our environment with the necessary packages:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

For this tutorial, you’ll need PyTorch and torchvision installed. If you don’t have them yet, you can install them using pip:
pip install torch torchvision
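Once installed, you can quickly confirm the version and check whether PyTorch can see a GPU:

print(torch.__version__)          # PyTorch version string
print(torch.cuda.is_available())  # True if a CUDA-capable GPU is usable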
Step 2 – Load and Prepare the MNIST Dataset
MNIST contains 60,000 training images and 10,000 test images of handwritten digits (0-9). Here’s how to load and prepare the dataset:
# Define transformations for the images
transform = transforms.Compose([
    transforms.ToTensor(),                      # Convert images to PyTorch tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
])
# Load the training dataset
train_dataset = torchvision.datasets.MNIST(
    root='./data',        # Where to store the dataset
    train=True,           # This is training data
    download=True,        # Download if not present
    transform=transform   # Apply transformations
)
# Load the test dataset
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)
# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
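Each iteration over a loader yields one batch of images and labels; a quick shape check (a small sketch) shows the (batch, channels, height, width) layout:

# Peek at one batch from the training loader
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

Let’s visualize some of the images to get a feel for the data: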
# Function to display images
def show_images(images, labels):
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    axes = axes.flatten()
    for i in range(10):
        axes[i].imshow(images[i].reshape(28, 28), cmap='gray')
        axes[i].set_title(f"Label: {labels[i]}")
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()
# Get some random training images
dataiter = iter(train_loader)
images, labels = next(dataiter)
# Show images
show_images(images[:10], labels[:10])
Step 3 – Create a Neural Network Model
Now let’s create a simple neural network for classifying the digits:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
# Initialize the model
model = SimpleNN()
print(model)

This network consists of:
- A flatten layer to convert 2D images to 1D vectors
- Three fully connected (linear) layers
- ReLU activation functions between the layers
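Before training, I like to push a dummy batch through the network to confirm the shapes line up; a quick sketch:

# Sanity check: a fake batch of 4 images should produce 4 rows of 10 logits
dummy = torch.randn(4, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([4, 10])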
Step 4 – Train the Neural Network
Here’s how to train the neural network:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Check if GPU is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)
# Training loop
num_epochs = 5
train_losses = []
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        # Move tensors to the configured device
        inputs, labels = inputs.to(device), labels.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        # Print statistics
        running_loss += loss.item()
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
            train_losses.append(running_loss/100)
            running_loss = 0.0
print('Training finished!')

I find it helpful to visualize the training process by plotting the loss:
plt.figure(figsize=(10, 5))
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Training Steps (x100)')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
Step 5 – Evaluate the Model
After training, we need to evaluate how well our model performs on unseen data:
# Evaluate the model on the test set
model.eval()  # Switch to evaluation mode
correct = 0
total = 0
with torch.no_grad():  # No gradients needed for evaluation
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy on the test set: {accuracy:.2f}%')
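A single number can hide weak spots, so I sometimes break the accuracy down per digit. A minimal sketch, reusing the model and test loader from above:

# Optional: per-digit accuracy breakdown (a sketch)
class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        for label, pred in zip(labels, preds):
            class_total[label] += 1
            class_correct[label] += int(label == pred)
for digit in range(10):
    print(f"Digit {digit}: {100 * class_correct[digit] / class_total[digit]:.1f}%")

Let’s visualize some predictions to see how our model is doing: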
# Function to show predictions
def show_predictions(images, labels, preds):
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    axes = axes.flatten()
    for i in range(10):
        axes[i].imshow(images[i].cpu().reshape(28, 28), cmap='gray')
        color = 'green' if preds[i] == labels[i] else 'red'
        axes[i].set_title(f"Pred: {preds[i]}, True: {labels[i]}", color=color)
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()
# Get predictions for some test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, preds = torch.max(outputs, 1)
# Show predictions
show_predictions(images[:10], labels[:10], preds[:10])

Step 6 – Save and Load Your Model
It’s good practice to save your trained model so you can use it later without retraining:
# Save the model
torch.save(model.state_dict(), 'mnist_model.pth')
print("Model saved!")
# Load the model (for future use)
loaded_model = SimpleNN()
loaded_model.load_state_dict(torch.load('mnist_model.pth'))
loaded_model.to(device)
loaded_model.eval()
print("Model loaded!")Method 7 – Use a Convolutional Neural Network (CNN)
While our simple neural network works reasonably well, convolutional neural networks (CNNs) are typically better for image processing tasks:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)  # 28x28 -> 14x14
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)  # 14x14 -> 7x7
        x = x.view(-1, 64 * 7 * 7)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Initialize the CNN model
cnn_model = CNN().to(device)
You can train and evaluate this CNN model using the same code as for the simple neural network. CNNs typically achieve higher accuracy on the MNIST dataset, often exceeding 99%.
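If you want to try it, here’s a minimal sketch of that reuse, with the hyperparameters copied from Step 4:

# A fresh optimizer pointed at the CNN's parameters; the Step 4 training
# loop and Step 5 evaluation work unchanged with cnn_model in place of model
cnn_optimizer = optim.SGD(cnn_model.parameters(), lr=0.01, momentum=0.9)

# Quick shape check: two 2x2 poolings turn 28x28 into 7x7, hence 64 * 7 * 7
dummy = torch.randn(4, 1, 28, 28).to(device)
print(cnn_model(dummy).shape)  # torch.Size([4, 10])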
I hope you found this tutorial helpful! With these steps, you’ve learned how to create, train, and evaluate neural networks using PyTorch on the MNIST dataset. This serves as a great foundation for more complex computer vision tasks.
The techniques we’ve covered – data loading, model creation, training loops, and evaluation – are fundamental to almost all deep learning projects. You can build on this knowledge to tackle more challenging datasets and create more sophisticated neural network architectures.