PyTorch Model Summary

While debugging a complex CNN architecture, I needed to quickly check the number of parameters and the shape of each layer. The issue was that PyTorch doesn’t have a built-in summary function like Keras does. After experimenting with different solutions, I found several effective ways to visualize a model’s architecture.

In this article, I’ll share five practical methods to generate model summaries in PyTorch that have saved me countless hours. Let’s get into it!

Methods to Visualize Your Neural Networks

Below, I’ll walk through each method with a working example, starting with the simplest.

1 – Use print() Function

The simplest approach is to use Python’s built-in print function:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 112 * 112, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = x.view(-1, 16 * 112 * 112)
        x = self.fc(x)
        return x

model = SimpleModel()
print(model)

This gives us:

SimpleModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=200704, out_features=10, bias=True)
)


While this method shows the layer structure, it doesn’t provide parameter counts or shape information. That’s where specialized libraries come in handy.
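If all you need is the missing parameter count, you can compute it directly from model.parameters() with no extra libraries. A minimal sketch, reusing the SimpleModel definition from above so it runs on its own:

```python
import torch.nn as nn

# Same SimpleModel as above, reproduced so the snippet is self-contained
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(16 * 112 * 112, 10)

model = SimpleModel()

# Sum the element counts of every parameter tensor
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total:,}")       # Total params: 2,007,498
print(f"Trainable params: {trainable:,}")
```

This one-liner is handy inside training scripts, where you often want to log the parameter count rather than print a full summary.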


2 – Use torchsummary Package

The torchsummary package provides a Keras-like summary function. Here’s how to use it:

from torchsummary import summary

# Assuming we're working with 224x224 RGB images
summary(model, (3, 224, 224))

This produces a more detailed output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 224, 224]             448
              ReLU-2         [-1, 16, 224, 224]               0
         MaxPool2d-3         [-1, 16, 112, 112]               0
            Linear-4                   [-1, 10]       2,007,050
================================================================
Total params: 2,007,498
Trainable params: 2,007,498
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 13.78
Params size (MB): 7.66
Estimated Total Size (MB): 22.01
----------------------------------------------------------------


This is much more informative! We can now see:

  • Layer names and types
  • Output shapes
  • Parameter counts per layer
  • Total parameter counts
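As a quick sanity check on the per-layer counts, the Conv2d row’s 448 parameters follow from (3 × 3 × 3 weights + 1 bias) per filter, times 16 filters. A small sketch verifying this without torchsummary:

```python
import torch.nn as nn

# The same first layer as SimpleModel: 16 filters, 3 input channels, 3x3 kernel
conv = nn.Conv2d(3, 16, 3, padding=1)

# Each filter has 3*3*3 weights plus one bias
n_params = sum(p.numel() for p in conv.parameters())
print(n_params)  # 448
assert n_params == (3 * 3 * 3 + 1) * 16
```

Doing this arithmetic by hand occasionally is a good habit; it catches layers that are accidentally far larger than you intended.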

To install torchsummary:

pip install torchsummary


3 – Use torchinfo (Previously torch-summary)

Torchinfo is an improved version of torchsummary with additional features. It’s my go-to solution for model visualization:

from torchinfo import summary

# More detailed summary with batch size information
summary(model, input_size=(32, 3, 224, 224), verbose=2)

The output includes:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
SimpleModel                              --                        --
├─Conv2d: 1-1                            [32, 16, 224, 224]        448
├─ReLU: 1-2                              [32, 16, 224, 224]        --
├─MaxPool2d: 1-3                         [32, 16, 112, 112]        --
├─Linear: 1-4                            [32, 10]                  2,007,050
==========================================================================================
Total params: 2,007,498
Trainable params: 2,007,498
Non-trainable params: 0
Total mult-adds (G): 0.76
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 154.14
Params size (MB): 7.66
Estimated Total Size (MB): 181.07
==========================================================================================

This provides even more useful information:

  • Hierarchical view of the model
  • Memory usage estimates
  • Multiply-add operations
  • Input/output sizes in MB
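The mult-adds figure can be sanity-checked by hand: a convolution costs roughly kernel_h × kernel_w × in_channels multiply-adds per output element, and a linear layer roughly in_features × out_features per sample. Here is my own back-of-envelope estimate for SimpleModel at batch size 32 — note that different tools count slightly differently, so it may not match a given library’s figure exactly:

```python
batch = 32

# Conv2d(3, 16, 3, padding=1) on 224x224 input: the output is
# 16 x 224 x 224, and each output element needs 3*3*3 multiply-adds
conv_madds = batch * 16 * 224 * 224 * (3 * 3 * 3)

# Linear(16*112*112, 10): one multiply-add per weight, per sample
fc_madds = batch * 16 * 112 * 112 * 10

total_madds = conv_madds + fc_madds
print(f"{total_madds / 1e9:.2f} G mult-adds")  # 0.76 G mult-adds
```

Estimates like this also make it obvious where the compute lives: here the single conv layer dominates, even though the linear layer holds almost all of the parameters.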

To install torchinfo:

pip install torchinfo


4 – Use torchviz for Visual Graphs

Sometimes, a visual representation is worth a thousand lines of text. The torchviz package creates graphical visualizations of your model:

from torchviz import make_dot

# Create a dummy input
x = torch.randn(1, 3, 224, 224)
y = model(x)

# Generate visualization
dot = make_dot(y, params=dict(model.named_parameters()))
dot.render("model_architecture", format="png")

This generates a computational graph of your model, showing how tensors flow through the network. It’s particularly useful for understanding complex architectures.

To install torchviz:

pip install torchviz


5 – Use PyTorch’s Forward Hooks for Custom Summaries

For more flexibility, you can create a custom summary using PyTorch’s hook functions:

def hook_fn(module, input, output):
    print(f"Module: {module.__class__.__name__}")
    print(f"Input shape: {input[0].shape}")
    print(f"Output shape: {output.shape}")
    print("-------------------------------")

hooks = []
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear, nn.MaxPool2d)):
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)

# Run a forward pass
x = torch.randn(1, 3, 224, 224)
output = model(x)

# Remove hooks
for hook in hooks:
    hook.remove()

This approach lets you track exactly what information you care about. It’s especially useful when debugging specific layers or tracking intermediate activations.
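A natural extension is to collect the information into a data structure instead of printing it, giving you a reusable mini-summary. A sketch building on the hook idea above (the record_fn name and summary_rows list are my own, and I use a stand-in nn.Sequential model so the snippet runs on its own):

```python
import torch
import torch.nn as nn

# Minimal stand-in model with the same layers as SimpleModel
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(16 * 112 * 112, 10),
)

summary_rows = []  # (layer name, output shape, param count)

def record_fn(module, inputs, output):
    n_params = sum(p.numel() for p in module.parameters())
    summary_rows.append(
        (module.__class__.__name__, tuple(output.shape), n_params)
    )

# Hook only the layers that own parameters
hooks = [
    m.register_forward_hook(record_fn)
    for m in model.modules()
    if isinstance(m, (nn.Conv2d, nn.Linear))
]

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

for h in hooks:
    h.remove()

for name, shape, n in summary_rows:
    print(f"{name:10s} {str(shape):22s} {n:,}")
```

Because summary_rows is plain Python data, you can format it however you like, log it, or compare it between model revisions.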

Handle Multiple Inputs

When your model takes multiple inputs, you can adapt the above methods slightly. For example, with torchinfo:

from torchinfo import summary

class MultiInputModel(nn.Module):
    def __init__(self):
        super(MultiInputModel, self).__init__()
        self.image_conv = nn.Conv2d(3, 16, 3)
        self.text_embedding = nn.Embedding(1000, 64)
        self.fc = nn.Linear(16*222*222 + 64*10, 10)

    def forward(self, image, text):
        img_feat = self.image_conv(image)
        img_feat = img_feat.view(img_feat.size(0), -1)

        text_feat = self.text_embedding(text)
        text_feat = text_feat.view(text_feat.size(0), -1)

        combined = torch.cat([img_feat, text_feat], dim=1)
        return self.fc(combined)

model = MultiInputModel()

# Define input sizes (and dtypes) for each input; the text input
# must be integer indices for nn.Embedding, not floats
input_size = [(1, 3, 224, 224), (1, 10)]
summary(model, input_size=input_size, dtypes=[torch.float, torch.long])

This allows you to visualize models with multiple input branches.
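If you just want to confirm the wiring without a summary library, you can also run the model directly on dummy tensors and check the output shape. A quick sketch using the MultiInputModel above, reproduced so it runs standalone:

```python
import torch
import torch.nn as nn

# Same MultiInputModel as above
class MultiInputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_conv = nn.Conv2d(3, 16, 3)
        self.text_embedding = nn.Embedding(1000, 64)
        self.fc = nn.Linear(16 * 222 * 222 + 64 * 10, 10)

    def forward(self, image, text):
        img_feat = self.image_conv(image).view(image.size(0), -1)
        text_feat = self.text_embedding(text).view(text.size(0), -1)
        return self.fc(torch.cat([img_feat, text_feat], dim=1))

model = MultiInputModel()

# Dummy inputs: a float image batch and integer token indices
image = torch.randn(2, 3, 224, 224)
text = torch.randint(0, 1000, (2, 10))

with torch.no_grad():
    out = model(image, text)
print(out.shape)  # torch.Size([2, 10])
```

This kind of smoke test catches shape mismatches in the concatenation and final linear layer before you ever start training.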

Compare the Methods

Here’s a quick comparison of the methods to help you choose:

Method         Pros                                       Cons
print()        No dependencies, always available          Limited information, no parameter counts
torchsummary   Easy to use, parameter counts              Limited formatting options
torchinfo      Most comprehensive, memory estimates       Slight learning curve
torchviz       Visual representation, great for sharing   Requires GraphViz installation
custom hooks   Maximum flexibility                        Requires more code

In my experience, I use print() for quick checks, torchinfo for detailed analysis, and torchviz when I need to share model architecture with non-technical team members.

I hope you found this article helpful! Understanding your model architecture is a crucial step in the PyTorch development process, and these tools make it much easier to visualize what’s happening under the hood.

Remember to inspect your model summaries before training to catch dimensional issues early – it can save you hours of debugging time later!
