While debugging a complex CNN architecture, I needed to quickly check the number of parameters and the layer shapes. The issue was that PyTorch doesn't have a built-in summary function like Keras does. After experimenting with different solutions, I found several effective ways to visualize model architecture.
In this article, I'll share five practical methods to generate model summaries in PyTorch that have saved me countless hours. Let's dive in!
Methods to Visualize Your Neural Networks
Below, I'll walk through each method for visualizing your neural network.
1 – Use print() Function
The simplest approach is to use Python’s built-in print function:
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 112 * 112, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = x.view(-1, 16 * 112 * 112)
        x = self.fc(x)
        return x

model = SimpleModel()
print(model)

This gives us:
SimpleModel(
(conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu): ReLU()
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc): Linear(in_features=200704, out_features=10, bias=True)
)
While this method shows the layer structure, it doesn’t provide parameter counts or shape information. That’s where specialized libraries come in handy.
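That said, if all you need is the parameter count, you can compute it directly from `model.parameters()` without any extra packages. Here's a minimal sketch (the `count_parameters` helper name is my own, not a PyTorch API):

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (total, trainable) parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

# Same layers as SimpleModel above, written as a Sequential for brevity
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),   # 3*16*3*3 + 16 = 448 params
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 112 * 112, 10),    # 200704*10 + 10 = 2,007,050 params
)

total, trainable = count_parameters(model)
print(total, trainable)  # 2007498 2007498
```

This already covers the most common question ("how big is this model?"), but it won't give you per-layer output shapes.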
2 – Use torchsummary Package
The torchsummary package provides a Keras-like summary function. Here’s how to use it:
from torchsummary import summary
# Assuming we're working with 224x224 RGB images
summary(model, (3, 224, 224))

This produces a more detailed output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 224, 224]             448
              ReLU-2         [-1, 16, 224, 224]               0
         MaxPool2d-3         [-1, 16, 112, 112]               0
            Linear-4                   [-1, 10]       2,007,050
================================================================
Total params: 2,007,498
Trainable params: 2,007,498
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 13.78
Params size (MB): 7.66
Estimated Total Size (MB): 22.01
----------------------------------------------------------------
This is much more informative! We can now see:
- Layer names and types
- Output shapes
- Parameter counts per layer
- Total parameter counts
To install torchsummary:
pip install torchsummary
3 – Use torchinfo (Previously torch-summary)
Torchinfo is an improved version of torchsummary with additional features. It’s my go-to solution for model visualization:
from torchinfo import summary
# More detailed summary with batch size information
summary(model, input_size=(32, 3, 224, 224), verbose=2)

The output includes:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
SimpleModel -- --
├─Conv2d: 1-1 [32, 16, 224, 224] 448
├─ReLU: 1-2 [32, 16, 224, 224] --
├─MaxPool2d: 1-3 [32, 16, 112, 112] --
├─Linear: 1-4 [32, 10] 2,007,050
==========================================================================================
Total params: 2,007,498
Trainable params: 2,007,498
Non-trainable params: 0
Total mult-adds (G): 27.28
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 154.14
Params size (MB): 7.66
Estimated Total Size (MB): 181.07
==========================================================================================

This provides even more useful information:
- Hierarchical view of the model
- Memory usage estimates
- Multiply-add operations
- Input/output sizes in MB
To install torchinfo:
pip install torchinfo
4 – Use torchviz for Visual Graphs
Sometimes, a visual representation is worth a thousand lines of text. The torchviz package creates graphical visualizations of your model:
from torchviz import make_dot
# Create a dummy input
x = torch.randn(1, 3, 224, 224)
y = model(x)
# Generate visualization
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("model_architecture", format="png")

This generates a computational graph of your model, showing how tensors flow through the network. It's particularly useful for understanding complex architectures.
To install torchviz:
pip install torchviz

Note that torchviz also needs the system Graphviz package installed to render the output image.
5 – Use PyTorch’s hook_fn for Custom Summaries
For more flexibility, you can create a custom summary using PyTorch’s hook functions:
def hook_fn(module, input, output):
    print(f"Module: {module.__class__.__name__}")
    print(f"Input shape: {input[0].shape}")
    print(f"Output shape: {output.shape}")
    print("-------------------------------")

hooks = []
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear, nn.MaxPool2d)):
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)

# Run a forward pass
x = torch.randn(1, 3, 224, 224)
output = model(x)

# Remove hooks
for hook in hooks:
    hook.remove()

This approach lets you track exactly what information you care about. It's especially useful when debugging specific layers or tracking intermediate activations.
Handle Multiple Inputs
When your model takes multiple inputs, you can adapt the above methods slightly. For example, with torchinfo:
import torch
import torch.nn as nn
from torchinfo import summary

class MultiInputModel(nn.Module):
    def __init__(self):
        super(MultiInputModel, self).__init__()
        self.image_conv = nn.Conv2d(3, 16, 3)
        self.text_embedding = nn.Embedding(1000, 64)
        self.fc = nn.Linear(16 * 222 * 222 + 64 * 10, 10)

    def forward(self, image, text):
        img_feat = self.image_conv(image)
        img_feat = img_feat.view(img_feat.size(0), -1)
        text_feat = self.text_embedding(text)
        text_feat = text_feat.view(text_feat.size(0), -1)
        combined = torch.cat([img_feat, text_feat], dim=1)
        return self.fc(combined)

model = MultiInputModel()

# Define input sizes for each input; the embedding input must be an
# integer tensor, so pass torch.long as its dtype
input_size = [(1, 3, 224, 224), (1, 10)]
summary(model, input_size=input_size, dtypes=[torch.float, torch.long])

This allows you to visualize models with multiple input branches.
Compare the Methods
Here’s a quick comparison of the methods to help you choose:
| Method | Pros | Cons |
|---|---|---|
| print() | No dependencies, always available | Limited information, no parameter counts |
| torchsummary | Easy to use, parameter counts | Limited formatting options |
| torchinfo | Most comprehensive, memory estimates | Slight learning curve |
| torchviz | Visual representation, great for sharing | Requires GraphViz installation |
| custom hooks | Maximum flexibility | Requires more code |
In my experience, I use print() for quick checks, torchinfo for detailed analysis, and torchviz when I need to share model architecture with non-technical team members.
I hope you found this article helpful! Understanding your model architecture is a crucial step in the PyTorch development process, and these tools make it much easier to visualize what’s happening under the hood.
Remember to inspect your model summaries before training to catch dimensional issues early – it can save you hours of debugging time later!
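A minimal version of that pre-training sanity check, assuming 224x224 RGB inputs and the SimpleModel layers from earlier, looks like this:

```python
import torch
import torch.nn as nn

# Same layers as SimpleModel above, written as a Sequential
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 112 * 112, 10),
)

# Push one dummy batch through the model; this catches shape
# mismatches immediately instead of partway into a training run
dummy = torch.randn(2, 3, 224, 224)
out = model(dummy)
assert out.shape == (2, 10), f"unexpected output shape: {out.shape}"
print("shape check passed:", tuple(out.shape))
```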
I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last five years. During this time I have gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, and more, working with clients in the United States, Canada, the United Kingdom, Australia, and New Zealand. Check out my profile.