How to Use PyTorch Flatten for Neural Network Models

In my decade-plus journey as a Python developer, I’ve found that reshaping tensors is a crucial operation when building neural networks. One of the most common reshaping operations I perform is flattening multi-dimensional data into a 1D or 2D tensor.

PyTorch’s Flatten layer is a simple yet useful tool that I use regularly in my deep learning projects. It’s especially useful when transitioning from convolutional layers to fully connected layers.

Today, I will guide you through everything you need to know about the Flatten operation in PyTorch, covering basic usage as well as advanced implementation techniques.

Methods to Flatten Tensors in PyTorch

PyTorch Flatten is a layer that reshapes a tensor by “flattening” it across specified dimensions. It takes a multi-dimensional input and converts it to a lower-dimensional form without changing the data.

I use this operation primarily when:

  • Converting feature maps from CNNs to vectors for fully connected layers
  • Preparing batched data for linear transformations
  • Simplifying model architectures by standardizing dimensions

Let’s dive into how to implement this useful operation in our PyTorch models.

1: Use torch.nn.Flatten Module

The easiest way to flatten tensors in PyTorch is the torch.nn.Flatten module. I prefer this approach in most of my models because it’s cleaner and more maintainable.

import torch
import torch.nn as nn

# Create a sample 4D input tensor (batch_size, channels, height, width)
batch_size = 32
input_tensor = torch.randn(batch_size, 3, 224, 224)

# Create a flatten layer (default start_dim=1, end_dim=-1)
flatten_layer = nn.Flatten()

# Apply flattening
output_tensor = flatten_layer(input_tensor)

print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")

When I run this code, I get:

Input shape: torch.Size([32, 3, 224, 224])
Output shape: torch.Size([32, 150528])


Notice how it preserved the batch dimension (index 0) but flattened all other dimensions (3×224×224 = 150,528).

Customize Flattening Dimensions

Sometimes I need to flatten only specific dimensions. The nn.Flatten module accepts two parameters:

  • start_dim: first dimension to flatten (default: 1)
  • end_dim: last dimension to flatten (default: -1)

Here’s how I customize the flattening:

# Flatten only the spatial dimensions (height and width)
spatial_flatten = nn.Flatten(start_dim=2, end_dim=3)
output = spatial_flatten(input_tensor)
print(f"Spatial flattened shape: {output.shape}")

Result:

Spatial flattened shape: torch.Size([32, 3, 50176])

This preserves both batch size and channels but flattens height and width.

2: Use torch.flatten Function

When I need a quick, one-off flattening operation without defining a separate layer, I use the functional approach with torch.flatten():

import torch

# Create a sample tensor
x = torch.randn(8, 4, 16, 16)

# Flatten from dimension 1 to the end
flattened = torch.flatten(x, start_dim=1)
print(f"Original shape: {x.shape}")
print(f"Flattened shape: {flattened.shape}")

Output:

Original shape: torch.Size([8, 4, 16, 16])
Flattened shape: torch.Size([8, 1024])


This method is particularly useful in forward functions when I need to flatten tensors on the fly.
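For instance, a forward method can flatten feature maps inline without a dedicated layer. Here is a minimal sketch (the layer sizes and the TinyNet name are illustrative, not from any standard model):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # padding=1 with kernel_size=3 preserves the 16x16 spatial size
        self.conv = nn.Conv2d(4, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 16 * 16, 10)

    def forward(self, x):
        x = self.conv(x)
        # Flatten on the fly, keeping the batch dimension intact
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

model = TinyNet()
out = model(torch.randn(8, 4, 16, 16))
print(out.shape)  # torch.Size([8, 10])
```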


3: Use view() or reshape() Methods

Before torch.nn.Flatten was introduced, I used to flatten tensors using the view() or reshape() methods:

# Using view() method
x = torch.randn(16, 3, 28, 28)
flattened_view = x.view(16, -1)  # -1 automatically calculates the correct size
print(f"Flattened with view(): {flattened_view.shape}")

# Using reshape() method
flattened_reshape = x.reshape(16, -1)
print(f"Flattened with reshape(): {flattened_reshape.shape}")

Result:

Flattened with view(): torch.Size([16, 2352])
Flattened with reshape(): torch.Size([16, 2352])


While these methods work, I now prefer using nn.Flatten for better readability and consistency in model definitions.
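One practical difference between the two: view() requires the tensor’s memory to be contiguous, while reshape() silently copies the data when it has to. A quick sketch of where this bites:

```python
import torch

x = torch.randn(16, 3, 28, 28)
# Transposing produces a non-contiguous tensor
t = x.transpose(1, 2)

try:
    t.view(16, -1)  # fails: view() needs contiguous memory
except RuntimeError as e:
    print(f"view() failed: {e}")

flat = t.reshape(16, -1)  # reshape() copies when needed
print(flat.shape)  # torch.Size([16, 2352])
```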


Practical Example: CNN with Flatten Layer

Let me show you how I incorporate Flatten in a real-world CNN model for image classification:

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()

        # Convolutional layers
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Flatten layer - crucial for transition to fully connected layers
        self.flatten = nn.Flatten()

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 56 * 56, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)  # 10 output classes
        )

    def forward(self, x):
        # Pass input through convolutional layers
        x = self.conv_layers(x)

        # Flatten the feature maps
        x = self.flatten(x)

        # Pass through fully connected layers
        x = self.fc_layers(x)

        return x

# Create model and check output with a sample input
model = SimpleCNN()
sample_input = torch.randn(1, 3, 224, 224)  # Single image, RGB, 224x224
output = model(sample_input)
print(f"Model output shape: {output.shape}")

Output:

Model output shape: torch.Size([1, 10])

In this example, the Flatten layer serves as a bridge between the convolutional layers (which output 3D tensors) and the fully connected layers (which expect 2D tensors).


Common Mistakes When Using Flatten

Throughout my years of experience, I’ve encountered several common mistakes when using the Flatten operation:

  1. Forgetting about batch dimension: Always remember that in most deep learning scenarios, the first dimension is the batch size and should be preserved.
  2. Incorrect dimension calculation: When connecting flattened output to a linear layer, ensure you calculate the correct input size:

# Incorrect: including the batch size in the input features
x = torch.randn(8, 64, 7, 7)
flat = nn.Flatten()(x)                   # shape: [8, 3136]
linear = nn.Linear(8 * 64 * 7 * 7, 100)  # Wrong! The batch dimension is not flattened

# Correct: only the flattened feature dimensions
linear = nn.Linear(64 * 7 * 7, 100)      # 3136 input features

  3. Flattening at the wrong point: Flattening too early in your network can lose spatial information that convolutional layers could exploit.
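To illustrate the batch-dimension pitfall concretely: passing start_dim=0 collapses the batch axis too, mixing all samples into one long vector:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, 7, 7)

# Default keeps the batch dimension (start_dim=1)
good = nn.Flatten()(x)
print(good.shape)  # torch.Size([8, 3136])

# start_dim=0 flattens everything, including the batch axis
bad = torch.flatten(x, start_dim=0)
print(bad.shape)  # torch.Size([25088])
```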


Performance Considerations

In my experience, nn.Flatten has minimal computational overhead. However, I’ve found that flattening large tensors can sometimes cause memory spikes during training.

If you’re working with extremely large feature maps, consider downsampling before flattening or using techniques like global average pooling to reduce dimensions more efficiently.
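As a sketch of the global-average-pooling alternative (the layer sizes here are illustrative): pooling each feature map down to 1×1 before flattening shrinks the linear layer’s input from 200,704 features to just 64:

```python
import torch
import torch.nn as nn

feature_maps = torch.randn(32, 64, 56, 56)  # e.g. CNN output

# Naive flatten: 64 * 56 * 56 = 200,704 inputs to the linear layer
naive = nn.Sequential(nn.Flatten(), nn.Linear(64 * 56 * 56, 10))

# Global average pooling first: only 64 inputs to the linear layer
gap = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 10))

print(naive(feature_maps).shape)  # torch.Size([32, 10])
print(gap(feature_maps).shape)    # torch.Size([32, 10])
```

Both produce the same output shape, but the pooled version has orders of magnitude fewer parameters in its linear layer.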

When to Use PyTorch Flatten

Over the years, I’ve developed a sense for when flattening is appropriate:

  • Use Flatten when: Transitioning from convolutional to fully-connected layers, preparing data for RNNs, or implementing certain types of attention mechanisms.
  • Avoid Flatten when: Working with transformers (which typically maintain spatial dimensions), implementing fully convolutional networks, or when spatial information is critical throughout the entire network.

PyTorch’s Flatten operation is an essential tool in any deep learning practitioner’s toolkit. By understanding how to use it effectively, you’ll be able to design more flexible and powerful neural network architectures.
