How to Load PyTorch Models

Recently, I worked on a deep learning project that required me to deploy a pre-trained PyTorch model in a production environment. I encountered challenges loading the PyTorch models correctly, especially when dealing with various model architectures and saving formats.

In this tutorial, I will cover multiple ways to load PyTorch models (using torch.load, state dictionaries, and more).

So let’s get started!

Understand PyTorch Model Saving and Loading

PyTorch offers several methods to save and load models, each with its advantages. Before we look at loading models, it’s important to understand how PyTorch saves them.
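Both saving styles come down to what you pass to torch.save. As a minimal sketch (the model and file names here are placeholders):

```python
import torch
import torch.nn as nn

# A tiny model just to illustrate the two saving styles
model = nn.Linear(10, 2)

# Style 1: save the entire model object (architecture + parameters, via pickle)
torch.save(model, 'full_model.pth')

# Style 2: save only the state dictionary (just the parameter tensors)
torch.save(model.state_dict(), 'model_state_dict.pth')

print(list(model.state_dict().keys()))  # ['weight', 'bias']
```

Style 1 ties the file to the exact class definition and PyTorch version; Style 2 stores only tensors, which is why it travels better.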

Method 1 – Load a Complete Model

The simplest way to load a PyTorch model is to use the torch.load() function. This method loads the entire model, including architecture and parameters.

import torch
import torch.nn as nn

# Define the model class exactly as it was when saved
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Load the complete model
model = torch.load('my_model.pth', weights_only=False)

# Set to eval mode
model.eval()

# Dummy input for inference
input_data = torch.randn(1, 10)
output = model(input_data)

# Print the output
print(output)

I executed the above example code and added the screenshot below.


This method works great when you want to load the same model architecture that was saved. However, it can sometimes cause version compatibility issues if you’re loading a model saved with a different PyTorch version.


Method 2 – Load Using State Dictionaries (Recommended)

The most flexible approach is to save and load only the model’s state dictionary. This contains just the model parameters without the architecture.

import torch
import torch.nn as nn

# Define the model architecture
class MyModelClass(nn.Module):
    def __init__(self):
        super(MyModelClass, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Initialize the model
model = MyModelClass()

# Load the saved state dictionary
model.load_state_dict(torch.load('model_state_dict.pth'))

# Set to evaluation mode
model.eval()

# Run a dummy inference
input_data = torch.randn(1, 10)
output = model(input_data)

# Print the output
print(output)

I executed the above example code and added the screenshot below.


This method offers better compatibility across PyTorch versions and gives you more flexibility to modify the architecture if needed.


Method 3 – Load Models for Different Devices (CPU/GPU)

When deploying models, you might need to load a model trained on GPU to a CPU-only environment, or vice versa.

import torch
import torch.nn as nn

# Define model class (must match original exactly)
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Initialize and load weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device)
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))
model.eval()

# Run a dummy inference
input_data = torch.randn(1, 4).to(device)
output = model(input_data)

print(output)

I executed the above example code and added the screenshot below.


The map_location parameter ensures your model loads correctly regardless of where it was trained.
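map_location also accepts a plain string. As a small sketch that forces every tensor onto the CPU (the state dict is created here only so the example is self-contained, and the file name is a placeholder):

```python
import torch
import torch.nn as nn

# Create and save a small state dict so the example is self-contained
torch.save(nn.Linear(4, 2).state_dict(), 'model_state_dict.pth')

# map_location='cpu' remaps every tensor onto the CPU at load time,
# even if the file was written from a CUDA device
state = torch.load('model_state_dict.pth', map_location='cpu')
print({k: v.device.type for k, v in state.items()})  # {'weight': 'cpu', 'bias': 'cpu'}
```

This is the safest default for deployment targets where you cannot assume a GPU is present.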


Method 4 – Load Checkpoint Files (for Training Resumption)

When training large models, you might want to save and load checkpoints to resume training later.

import torch
import torch.optim as optim
from my_model_architecture import MyModelClass

model = MyModelClass()
optimizer = optim.Adam(model.parameters())

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Resume training
model.train()

This approach saves not just the model parameters but also optimizer state, current epoch, and other training information.
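For completeness, here is a sketch of the saving side, producing a file in the format the loading code above expects (the model, epoch, and loss values are placeholders):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters())

# Bundle everything needed to resume training into one dictionary
torch.save({
    'epoch': 5,                                      # placeholder value
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.42,                                    # placeholder value
}, 'checkpoint.pth')
```

Saving the optimizer state matters for optimizers like Adam, which keep per-parameter running statistics that a fresh optimizer would lose.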


Method 5 – Load TorchScript Models

TorchScript models are optimized for production environments and can be loaded using:

import torch

# Load a TorchScript model
model = torch.jit.load('scripted_model.pt')
model.eval()

# Dummy input; the shape must match what the saved model expects
input_tensor = torch.randn(1, 10)

# Use the model
with torch.no_grad():
    output = model(input_tensor)

TorchScript models can run without Python dependencies, making them ideal for production deployment.
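To produce such a file in the first place, you can compile a regular module with torch.jit.script and save it. A minimal sketch (the model class and file name are placeholders):

```python
import torch
import torch.nn as nn

# A placeholder model to demonstrate scripting
class TinyModel(nn.Module):
    def forward(self, x):
        return x * 2

# Compile the module to TorchScript and save it
scripted = torch.jit.script(TinyModel())
scripted.save('scripted_model.pt')

# Load it back -- no Python class definition is needed this time
loaded = torch.jit.load('scripted_model.pt')
print(loaded(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])
```

Note that unlike Methods 1 and 2, the loading side never needs access to the original class, which is exactly what makes TorchScript suitable for C++ or mobile runtimes.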


Method 6 – Load Models from Hugging Face Hub

If you’re working with models from the Hugging Face Hub, you can load them directly:

import torch
from transformers import AutoModel

# Load a model like BERT
model = AutoModel.from_pretrained("bert-base-uncased")

# Load with specific device placement
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModel.from_pretrained("bert-base-uncased").to(device)

This is particularly useful for working with popular pre-trained models like BERT, GPT, or other transformers.

Handle Common Loading Issues

Sometimes you might encounter errors when loading models. Here are solutions for common issues:


Missing Keys or Unexpected Keys

If you get errors about missing or unexpected keys in the state dictionary:

# Load with strict=False to ignore missing or unexpected keys
model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)
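A frequent cause of key mismatches is a model saved from nn.DataParallel, which prefixes every key with "module.". A sketch of stripping that prefix (the model here is a placeholder, and the prefixed dict is simulated so the example is self-contained):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 2))

# Simulate a state dict saved from nn.DataParallel: every key
# carries a 'module.' prefix that the plain model does not expect
saved = {'module.' + k: v for k, v in model.state_dict().items()}

# Strip the prefix so the keys line up with the plain model again
cleaned = {k.removeprefix('module.'): v for k, v in saved.items()}
model.load_state_dict(cleaned)  # now loads without key errors
```

Renaming keys like this is usually safer than strict=False, which silently skips mismatched parameters rather than fixing them.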

Version Compatibility Issues

For models saved with an older PyTorch version:

import pickle
import torch

# Pass the pickle module explicitly when deserializing an older file
model = torch.load('old_model.pth', pickle_module=pickle, weights_only=False)

Memory Issues with Large Models

For very large models that might not fit in memory:

import torch

# Load on CPU first to avoid GPU out-of-memory during deserialization
model = torch.load('large_model.pth', map_location='cpu', weights_only=False)

# Then move the loaded model to the GPU
model.to('cuda')

All of these methods work, but I prefer the state dictionary approach (Method 2) for most scenarios. It offers the best flexibility and avoids many compatibility issues that can arise with the complete model loading approach.

I hope you found this article helpful.
