Recently, I worked on a deep learning project that required me to deploy a pre-trained PyTorch model in a production environment. I encountered challenges loading the PyTorch models correctly, especially when dealing with various model architectures and saving formats.
In this tutorial, I will cover multiple ways to load PyTorch models (using torch.load, state dictionaries, and more).
So let's get started!
Understand PyTorch Model Saving and Loading
PyTorch offers several methods to save and load models, each with its advantages. Before we look at loading models, it’s important to understand how PyTorch saves them.
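To make the loading methods below concrete, here is a minimal sketch of the two common ways a model gets saved in the first place (the TinyNet class and file names are placeholders for illustration):

```python
import torch
import torch.nn as nn

# A tiny placeholder model, just for demonstration
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()

# Option 1: save the entire model object (pickles the class along with the weights)
torch.save(model, 'tiny_full.pth')

# Option 2: save only the state dictionary (an ordered dict of parameter tensors)
torch.save(model.state_dict(), 'tiny_state.pth')

print(list(model.state_dict().keys()))  # ['fc.weight', 'fc.bias']
```

Option 2 is generally preferred because the file contains only tensors, not pickled Python classes, which is why it tends to survive code and version changes better.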
Method 1 – Load a Complete Model
The simplest way to load a PyTorch model is to use the torch.load() function. This method loads the entire model, including architecture and parameters.
import torch
import torch.nn as nn
# Define the model class exactly as it was when saved
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)
# Load the complete model (weights_only=False is required to unpickle a full model object in newer PyTorch versions)
model = torch.load('my_model.pth', weights_only=False)
# Set to eval mode
model.eval()
# Dummy input for inference
input_data = torch.randn(1, 10)
output = model(input_data)
# Print the output
print(output)
This method works great when you want to load the same model architecture that was saved. However, it can sometimes cause version compatibility issues if you’re loading a model saved with a different PyTorch version.
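For reference, a file like my_model.pth above would typically have been produced with torch.save on the model object itself. Here is a minimal round-trip sketch (the file name is assumed from the example):

```python
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Save the complete model object; the class definition must exist at load time
model = SimpleModel()
torch.save(model, 'my_model.pth')

# Round trip: unpickling a full model requires weights_only=False in newer PyTorch
reloaded = torch.load('my_model.pth', weights_only=False)
print(type(reloaded).__name__)  # SimpleModel
```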
Method 2 – Load Using State Dictionaries (Recommended)
The most flexible approach is to save and load only the model’s state dictionary. This contains just the model parameters without the architecture.
import torch
import torch.nn as nn
# Define the model architecture
class MyModelClass(nn.Module):
    def __init__(self):
        super(MyModelClass, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)
# Initialize the model
model = MyModelClass()
# Load the saved state dictionary
model.load_state_dict(torch.load('model_state_dict.pth'))
# Set to evaluation mode
model.eval()
# Run a dummy inference
input_data = torch.randn(1, 10)
output = model(input_data)
# Print the output
print(output)
This method offers better compatibility across PyTorch versions and gives you more flexibility to modify the architecture if needed.
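For completeness, here is how the model_state_dict.pth file used above would typically be created, plus a round-trip check showing that a fresh instance really picks up the saved weights (the file name is assumed from the example):

```python
import torch
import torch.nn as nn

class MyModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Save only the parameters
model = MyModelClass()
torch.save(model.state_dict(), 'model_state_dict.pth')

# Round trip: a fresh instance loads the saved weights
restored = MyModelClass()
restored.load_state_dict(torch.load('model_state_dict.pth'))
match = torch.equal(model.fc.weight, restored.fc.weight)
print(match)  # True
```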
Method 3 – Load Models for Different Devices (CPU/GPU)
When deploying models, you might need to load a model trained on GPU to a CPU-only environment, or vice versa.
import torch
import torch.nn as nn
# Define model class (must match original exactly)
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)
# Initialize and load weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device)
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))
model.eval()
# Run a dummy inference
input_data = torch.randn(1, 4).to(device)
output = model(input_data)
print(output)
The map_location parameter ensures your model loads correctly regardless of where it was trained.
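map_location accepts a device string, a torch.device object, or a callable. A small sketch of the common forms (the file name and state dict here are made up for the demo):

```python
import torch

# Write a small state dict to disk so we have something to load (demo file name)
state = {'fc.weight': torch.randn(2, 4), 'fc.bias': torch.randn(2)}
torch.save(state, 'demo_state.pth')

# Force everything onto the CPU with a device string
cpu_state = torch.load('demo_state.pth', map_location='cpu')

# Equivalent, using a torch.device object
dev_state = torch.load('demo_state.pth', map_location=torch.device('cpu'))

# A callable receives (storage, location) and returns the storage to use;
# returning it unchanged keeps tensors on the CPU
fn_state = torch.load('demo_state.pth', map_location=lambda storage, loc: storage)

print(cpu_state['fc.weight'].device)  # cpu
```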
Method 4 – Load Checkpoint Files (for Training Resumption)
When training large models, you might want to save and load checkpoints to resume training later.
import torch
import torch.optim as optim
from my_model_architecture import MyModelClass
model = MyModelClass()
optimizer = optim.Adam(model.parameters())
# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Resume training
model.train()

This approach saves not just the model parameters but also the optimizer state, the current epoch, and other training information.
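The checkpoint dictionary loaded above has to be written out with matching keys. A minimal sketch of the saving side (the keys and file name are assumed from the example):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A stand-in model and optimizer, just to have state to checkpoint
model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters())

# Bundle everything needed to resume training into one dictionary
checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.42,
}
torch.save(checkpoint, 'checkpoint.pth')

# Reload and confirm the extra metadata survives the round trip
loaded = torch.load('checkpoint.pth', weights_only=False)
print(loaded['epoch'], loaded['loss'])  # 5 0.42
```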
Method 5 – Load TorchScript Models
TorchScript models are optimized for production environments and can be loaded using:
import torch
# Load a TorchScript model
model = torch.jit.load('scripted_model.pt')
# Run inference with a dummy input (the shape depends on your model)
input_tensor = torch.randn(1, 10)
with torch.no_grad():
    output = model(input_tensor)

TorchScript models can run without Python dependencies, making them ideal for production deployment.
Method 6 – Load Models from Hugging Face Hub
If you’re working with models from the Hugging Face Hub, you can load them directly:
import torch
from transformers import AutoModel
# Load a model like BERT
model = AutoModel.from_pretrained("bert-base-uncased")
# Load with specific device placement
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModel.from_pretrained("bert-base-uncased").to(device)

This is particularly useful for working with popular pre-trained models like BERT, GPT, or other transformers.
Handle Common Loading Issues
Sometimes you might encounter errors when loading models. Here are solutions for common issues:
Missing Keys or Unexpected Keys
If you get errors about missing or unexpected keys in the state dictionary:
# Load with strict=False to ignore missing or unexpected keys
model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)

Version Compatibility Issues
For models saved with an older PyTorch version:
import torch
import pickle

# Pass an explicit pickle module when unpickling older files
model = torch.load('old_model.pth', pickle_module=pickle, weights_only=False)

Memory Issues with Large Models
For very large models that might not fit in memory:
import torch
# Load on CPU first, then move to GPU in parts
model = torch.load('large_model.pth', map_location='cpu')
model.to('cuda', non_blocking=True)  # Move to GPU asynchronously

Among these approaches, I prefer the state dictionary method (Method 2) for most scenarios. It offers the best flexibility and avoids many of the compatibility issues that can arise when loading a complete model.
I hope you found this article helpful.
I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, etc., for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and more. Check out my profile.