PyTorch MNIST – Complete Tutorial

While working on a project that required classifying handwritten digits, I found the MNIST dataset to be perfect for this task. Using PyTorch made implementing neural networks surprisingly simple.

In this article, I’ll walk you through creating, training, and testing a neural network on the MNIST dataset using PyTorch. We’ll start with the basics and gradually build up to a working model.

Whether you are new to deep learning or want to refresh your PyTorch skills, this tutorial provides hands-on experience with a fundamental image classification task.

So let’s get started!

Step 1 – Set Up Your Environment

Before we start coding, let’s import the packages we’ll use throughout this tutorial:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

For this tutorial, you’ll need PyTorch, torchvision, and Matplotlib (plus NumPy) installed. If you don’t have them yet, you can install them using pip:

pip install torch torchvision matplotlib numpy

Step 2 – Load and Prepare the MNIST Dataset

MNIST contains 60,000 training images and 10,000 test images of handwritten digits (0-9), each a 28×28 grayscale image. Here’s how to load and prepare the dataset:

# Define transformations for the images
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
])

# Load the training dataset
train_dataset = torchvision.datasets.MNIST(
    root='./data',  # Where to store the dataset
    train=True,     # This is training data
    download=True,  # Download if not present
    transform=transform  # Apply transformations
)

# Load the test dataset
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create data loaders for batch processing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Let’s visualize some of the images to get a feel for the data:

# Function to display images
def show_images(images, labels):
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    axes = axes.flatten()

    for i in range(10):
        axes[i].imshow(images[i].reshape(28, 28), cmap='gray')
        axes[i].set_title(f"Label: {labels[i]}")
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Get some random training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Show images
show_images(images[:10], labels[:10])

Step 3 – Create a Neural Network Model

Now let’s create a simple neural network for classifying the digits:

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# Initialize the model
model = SimpleNN()
print(model)

This network consists of:

  1. A flatten layer that turns each 1×28×28 image into a 784-element vector
  2. Three fully connected (linear) layers
  3. ReLU activation functions between the layers
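As a sanity check on the architecture, the sketch below counts the trainable parameters and runs a dummy batch through the network. The batch is random numbers, not real digits. The class is repeated so the snippet runs on its own. The count follows from each Linear(in, out) layer holding in × out weights plus out biases.

```python
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    # Same architecture as above, repeated so this snippet runs standalone
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.relu(self.fc1(self.flatten(x)))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

model = SimpleNN()

# Each Linear(in, out) layer has in*out weights plus out biases
expected = (784 * 128 + 128) + (128 * 64 + 64) + (64 * 10 + 10)
total = sum(p.numel() for p in model.parameters())
print(total)  # 109386

# A dummy batch shaped like one MNIST batch: (batch, channels, height, width)
dummy = torch.randn(64, 1, 28, 28)
logits = model(dummy)
print(logits.shape)  # torch.Size([64, 10])
```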

Step 4 – Train the Neural Network

Here’s how to train the neural network:

# Check if a GPU is available and move the model there before creating the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training loop
num_epochs = 5
train_losses = []

for epoch in range(num_epochs):
    model.train()  # training mode (matters if model.eval() was called earlier)
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(train_loader):
        # Move tensors to the configured device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
            train_losses.append(running_loss/100)
            running_loss = 0.0

print('Training finished!')
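The forward pass returns raw scores (logits) with no softmax at the end, because nn.CrossEntropyLoss applies log-softmax internally. The sketch below uses random logits and made-up labels to check that equivalence in two ways. It compares the loss with an explicit log_softmax followed by nll_loss, and with the "average negative log-probability of the correct class" formula.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# Fake logits for a batch of 4 images over 10 classes, plus integer labels
logits = torch.randn(4, 10)
labels = torch.tensor([3, 0, 9, 1])

loss = criterion(logits, labels)

# CrossEntropyLoss = log-softmax followed by negative log-likelihood (batch mean)
log_probs = F.log_softmax(logits, dim=1)
manual = F.nll_loss(log_probs, labels)

# Same thing spelled out: mean of -log p(correct class) over the batch
by_hand = -log_probs[torch.arange(4), labels].mean()
print(loss.item(), manual.item(), by_hand.item())
```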

I find it helpful to visualize the training process by plotting the loss:

plt.figure(figsize=(10, 5))
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Training Steps (x100)')
plt.ylabel('Loss')
plt.grid(True)
plt.show()

I executed the above example code and added the screenshot below.

[Screenshot: output of the example code above]
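A quick count tells you how many points the loss plot should contain. The sketch below uses a stand-in dataset of 60,000 zeros instead of the real images, so that len(loader) matches the real train_loader. That gives 938 batches per epoch. The loop logs every 100th batch, which yields 9 points per epoch and 45 over 5 epochs. The last 38 batches of each epoch are never logged, because running_loss is reset at the start of every epoch.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset with the same length as MNIST's training split (60,000);
# one scalar per sample keeps it cheap, since only the length matters here
fake_train = TensorDataset(torch.zeros(60_000))
loader = DataLoader(fake_train, batch_size=64, shuffle=True)

batches_per_epoch = len(loader)            # last partial batch is kept
expected = math.ceil(60_000 / 64)          # same value computed directly
logs_per_epoch = batches_per_epoch // 100  # the loop logs when (i+1) % 100 == 0
num_epochs = 5
points_in_plot = logs_per_epoch * num_epochs

print(batches_per_epoch, logs_per_epoch, points_in_plot)  # 938 9 45
```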

Step 5 – Evaluate the Model

After training, we need to evaluate how well our model performs on unseen data:

# Evaluate the model
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted digit
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy on the test set: {accuracy:.2f}%')
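You can check the accuracy arithmetic by hand on a tiny example. The logits below are invented, and they use 3 classes instead of 10 to stay readable. torch.max(outputs, 1) returns both the maximum values and their indices. The indices, which are equivalently outputs.argmax(dim=1), are the predicted classes.

```python
import torch

# Hypothetical logits for 3 test images over 3 classes (real MNIST output has 10)
outputs = torch.tensor([[2.0, 0.1, -1.0],
                        [0.2, 0.3, 3.0],
                        [1.0, 5.0, 0.0]])
labels = torch.tensor([0, 2, 0])

# The predicted class is the index of the largest logit in each row
predicted = outputs.argmax(dim=1)  # same indices as torch.max(outputs, 1)[1]
correct = (predicted == labels).sum().item()
accuracy = 100 * correct / labels.size(0)
print(predicted.tolist(), f"{accuracy:.2f}%")  # [0, 2, 1] 66.67%
```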

Let’s visualize some predictions to see how our model is doing:

# Function to show predictions
def show_predictions(images, labels, preds):
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    axes = axes.flatten()

    for i in range(10):
        axes[i].imshow(images[i].cpu().reshape(28, 28), cmap='gray')
        color = 'green' if preds[i] == labels[i] else 'red'
        axes[i].set_title(f"Pred: {preds[i]}, True: {labels[i]}", color=color)
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Get predictions for the first batch of test images
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

with torch.no_grad():  # no gradients are needed for inference
    outputs = model(images)
_, preds = torch.max(outputs, 1)

# Show predictions
show_predictions(images[:10], labels[:10], preds[:10])
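The model's outputs are logits, not probabilities. If you also want a confidence for each prediction, for example to spot images the model is unsure about, apply softmax across the class dimension. The logits in this sketch are made up for illustration.

```python
import torch

# Hypothetical logits for two images over 10 digit classes
logits = torch.tensor([[0.5, 0.1, 4.0, 0.0, 0.2, 0.1, 0.0, 0.3, 0.1, 0.0],
                       [1.2, 0.0, 0.1, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9]])

# Softmax turns each row of logits into probabilities that sum to 1
probs = torch.softmax(logits, dim=1)
confidence, preds = probs.max(dim=1)

for p, c in zip(preds.tolist(), confidence.tolist()):
    print(f"predicted {p} with probability {c:.2f}")
```

In the second row, the top class is only slightly ahead of two runners-up, so the confidence is low. Predictions like that are good candidates to inspect when a red (incorrect) title shows up.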

Step 6 – Save and Load Your Model

It’s good practice to save your trained model so you can use it later without retraining:

# Save the model
torch.save(model.state_dict(), 'mnist_model.pth')
print("Model saved!")

# Load the model (for future use)
loaded_model = SimpleNN()
# map_location lets weights saved on a GPU load on a CPU-only machine
loaded_model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
loaded_model.to(device)
loaded_model.eval()
print("Model loaded!")
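A round trip is a quick way to confirm that saving and loading work. Save a state_dict, load it into a freshly built model with the same architecture, and compare the outputs. This sketch uses a small stand-in model and a temporary file instead of SimpleNN and mnist_model.pth. map_location is the standard torch.load argument for loading GPU-saved weights on a CPU-only machine.

```python
import os
import tempfile
import torch
import torch.nn as nn

torch.manual_seed(0)
# A small stand-in model; the same calls work for SimpleNN or CNN
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")  # hypothetical file name
    torch.save(model.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it
    restored = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()

x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)  # True
```

Only the weights are stored in a state_dict, so the model class must be defined (or imported) wherever you load it.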

Step 7 – Use a Convolutional Neural Network (CNN)

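While our simple neural network works reasonably well, it discards the 2D layout of each image when it flattens the input. Convolutional neural networks (CNNs) keep that spatial structure, which is why they are typically better for image processing tasks: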

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # (batch, 64, 7, 7) -> (batch, 64 * 7 * 7)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Initialize the CNN model
cnn_model = CNN().to(device)

I executed the above example code and added the screenshot below.

[Screenshots: output of the example code above]
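The 64 * 7 * 7 input size of fc1 comes from tracking the spatial size through the layers. With kernel_size=3 and padding=1, each convolution keeps the size unchanged. Each 2×2 max-pool with stride 2 halves it: 28 → 28 → 14 → 14 → 7. The sketch below computes this with the usual output-size formula, using a hypothetical helper named conv_out. It then cross-checks the result with a dummy forward pass through the same convolutional layers.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel, stride=1, padding=0):
    # Output-size formula for Conv2d / MaxPool2d (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1

s = 28
s = conv_out(s, kernel=3, padding=1)  # conv1: 28 -> 28
s = conv_out(s, kernel=2, stride=2)   # pool:  28 -> 14
s = conv_out(s, kernel=3, padding=1)  # conv2: 14 -> 14
s = conv_out(s, kernel=2, stride=2)   # pool:  14 -> 7
print(s, 64 * s * s)  # 7 3136

# Cross-check by pushing a dummy batch through the same feature layers
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
)
out = features(torch.zeros(8, 1, 28, 28))
print(out.shape)  # torch.Size([8, 64, 7, 7])
```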

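You can train and evaluate this CNN model using the same code as for the simple neural network. Use cnn_model in place of model, and create a new optimizer from cnn_model.parameters(), because the existing optimizer only updates the simple network’s weights. CNNs typically achieve higher accuracy on the MNIST dataset, often exceeding 99%.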
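One way to reuse the same code for both networks is to wrap the training and evaluation loops in functions. The names train_one_epoch and evaluate below are hypothetical helpers, not part of PyTorch. The smoke test runs on random stand-in data rather than MNIST so that it executes offline, which means any accuracy it prints is meaningless.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, optimizer, criterion, device):
    # Same steps as the tutorial's training loop, usable with any model
    model.train()
    total_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)  # average loss over the epoch's batches

def evaluate(model, loader, device):
    # Same logic as the tutorial's evaluation code; returns accuracy in percent
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Smoke test on random stand-in data (NOT MNIST) so the snippet runs offline
torch.manual_seed(0)
device = torch.device("cpu")
fake = TensorDataset(torch.randn(128, 1, 28, 28), torch.randint(0, 10, (128,)))
loader = DataLoader(fake, batch_size=32, shuffle=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
# The optimizer must be built from the parameters of the model being trained
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

avg_loss = train_one_epoch(model, loader, optimizer, criterion, device)
acc = evaluate(model, loader, device)
print(f"loss {avg_loss:.4f}, accuracy {acc:.2f}%")
```

To apply this to the tutorial's objects, first create an optimizer from cnn_model.parameters(). Then call train_one_epoch(cnn_model, train_loader, cnn_optimizer, criterion, device) once per epoch, followed by evaluate(cnn_model, test_loader, device).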

I hope you found this tutorial helpful! With these steps, you’ve learned how to create, train, and evaluate neural networks using PyTorch on the MNIST dataset. This serves as a great foundation for more complex computer vision tasks.

The techniques we’ve covered – data loading, model creation, training loops, and evaluation – are fundamental to almost all deep learning projects. You can build on this knowledge to tackle more challenging datasets and create more sophisticated neural network architectures.
