How to Reshape a Tensor in PyTorch?

Working with PyTorch tensors often requires changing their shapes to fit specific neural network architectures. I’ve been using PyTorch for years in various deep learning projects, and reshaping tensors is something I do almost daily.

While building a computer vision model to classify American landmarks, I needed to transform my batch of images to meet my CNN’s input requirements. Many developers struggle with this fundamental operation.

In this tutorial, I’ll show you different ways to reshape tensors in PyTorch based on my experience.

Tensor Reshaping Methods

Reshaping a tensor means changing its dimensions without altering its data or total number of elements. Think of it like rearranging a deck of cards: same cards, different arrangement.

For example, a tensor with shape [12] can be reshaped to [3, 4], [4, 3], [2, 6], etc., because the total number of elements (12) remains constant.

Before we dive in, let’s import PyTorch and create a simple tensor to work with:

import torch

# Creating a tensor with numbers 0 to 11
original_tensor = torch.arange(12)
print("Original tensor:", original_tensor)
print("Original shape:", original_tensor.shape)

Output:

Original tensor: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
Original shape: torch.Size([12])


Now, let’s explore four different methods to reshape this tensor.

1- Use the reshape() Method

The reshape() method is the easiest way to change a tensor’s shape. It returns a tensor with the specified dimensions, reusing the original data when possible and copying it otherwise.

# Reshaping to a 3x4 tensor
reshaped_tensor = original_tensor.reshape(3, 4)
print("Reshaped tensor (3x4):", reshaped_tensor)
print("New shape:", reshaped_tensor.shape)

Output:

Reshaped tensor (3x4): tensor([[ 0,  1,  2,  3],
                               [ 4,  5,  6,  7],
                               [ 8,  9, 10, 11]])
New shape: torch.Size([3, 4])


The reshape() method is flexible and allows you to use -1 as a dimension, which tells PyTorch to automatically calculate that dimension based on the tensor’s total elements.

# Using -1 to automatically determine one dimension
auto_reshaped = original_tensor.reshape(-1, 4)
print("Auto-reshaped tensor:", auto_reshaped)
print("Auto-reshaped shape:", auto_reshaped.shape)

Output:

Auto-reshaped tensor: tensor([[ 0,  1,  2,  3],
                              [ 4,  5,  6,  7],
                              [ 8,  9, 10, 11]])
Auto-reshaped shape: torch.Size([3, 4])


I often use this when I know one dimension but want PyTorch to figure out the other automatically.
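One caveat: only a single dimension may be -1, since PyTorch can only infer one unknown. A quick sketch of what happens if you try to infer two at once:

```python
import torch

t = torch.arange(12)
try:
    # Only one dimension can be inferred with -1;
    # passing two raises a RuntimeError
    t.reshape(-1, -1)
except RuntimeError as e:
    print("Error:", e)
```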

2- Use the view() Method

The view() method works similarly to reshape(), with one important difference: it always returns a tensor that shares the same underlying data with the original, and it never makes a copy.

# Using view to reshape to 2x6
view_tensor = original_tensor.view(2, 6)
print("View tensor (2x6):", view_tensor)
print("View shape:", view_tensor.shape)

Output:

View tensor (2x6): tensor([[ 0,  1,  2,  3,  4,  5],
                           [ 6,  7,  8,  9, 10, 11]])
View shape: torch.Size([2, 6])
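Because view() shares storage with the original tensor, an in-place change through the view is immediately visible in the original. A small demonstration:

```python
import torch

original = torch.arange(12)
v = original.view(2, 6)

# Both tensors point at the same underlying storage
print(original.data_ptr() == v.data_ptr())  # True

# Writing through the view modifies the original as well
v[0, 0] = 99
print(original[0])  # tensor(99)
```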

Like reshape(), you can use -1 with view():

# Using -1 with view
auto_view = original_tensor.view(2, -1)
print("Auto-view shape:", auto_view.shape)

Output:

Auto-view shape: torch.Size([2, 6])

I prefer using view() when I want to ensure memory efficiency, especially when working with large datasets like satellite imagery of US national parks.

An important note: view() requires the tensor to be contiguous in memory. If your tensor isn’t contiguous (which can happen after certain operations), you’ll need to call .contiguous() first.

# Example with a non-contiguous tensor
transposed = view_tensor.transpose(0, 1)  # Makes tensor non-contiguous
print("Is transposed contiguous?", transposed.is_contiguous())

# This would fail: transposed.view(-1)
# Correct way:
contiguous_view = transposed.contiguous().view(-1)
print("Contiguous view shape:", contiguous_view.shape)

Output:

Is transposed contiguous? False
Contiguous view shape: torch.Size([12])

3- Use unsqueeze() for Adding Dimensions

Sometimes you need to add a dimension of size 1 to your tensor, especially when preparing data for models that expect specific input shapes. The unsqueeze() method is perfect for this.

# Adding a dimension at index 0
unsqueezed = original_tensor.unsqueeze(0)
print("Unsqueezed tensor:", unsqueezed)
print("Unsqueezed shape:", unsqueezed.shape)

Output:

Unsqueezed tensor: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])
Unsqueezed shape: torch.Size([1, 12])

I often use this when preparing a single sample for batch processing. For instance, when processing a single image of the Statue of Liberty for a pre-trained CNN:

# Simulating a single image tensor (3 channels, 224x224)
single_image = torch.rand(3, 224, 224)
print("Single image shape:", single_image.shape)

# Adding batch dimension for model input
batched_image = single_image.unsqueeze(0)
print("Batched image shape:", batched_image.shape)

Output:

Single image shape: torch.Size([3, 224, 224])
Batched image shape: torch.Size([1, 3, 224, 224])
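unsqueeze() also accepts negative indices, counted from the end, which is handy when you want to append a trailing dimension without counting the existing ones. A quick sketch:

```python
import torch

t = torch.arange(12)

# Negative indices count from the end: -1 appends a trailing dimension
print(t.unsqueeze(-1).shape)  # torch.Size([12, 1])
print(t.unsqueeze(0).shape)   # torch.Size([1, 12])
```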

4- Use squeeze() for Removing Dimensions

The squeeze() method does the opposite of unsqueeze(): it removes dimensions of size 1.

# Creating a tensor with singleton dimensions
tensor_with_ones = torch.rand(1, 3, 1, 2)
print("Original tensor shape:", tensor_with_ones.shape)

# Removing all singleton dimensions
squeezed = tensor_with_ones.squeeze()
print("Squeezed tensor shape:", squeezed.shape)

Output:

Original tensor shape: torch.Size([1, 3, 1, 2])
Squeezed tensor shape: torch.Size([3, 2])

You can also specify which dimension to squeeze:

# Removing only the first dimension
partially_squeezed = tensor_with_ones.squeeze(0)
print("Partially squeezed shape:", partially_squeezed.shape)

Output:

Partially squeezed shape: torch.Size([3, 1, 2])

I frequently use this method when processing model outputs that include batch dimensions I no longer need.
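Note that squeeze(dim) is a no-op when that dimension’s size isn’t 1, so it never silently drops real data:

```python
import torch

t = torch.rand(1, 3, 1, 2)

# Dimension 1 has size 3, so squeezing it changes nothing
print(t.squeeze(1).shape)  # torch.Size([1, 3, 1, 2])

# Without an argument, only the size-1 dimensions are removed
print(t.squeeze().shape)   # torch.Size([3, 2])
```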

Practical Example: Image Batch Processing

Let’s put these methods together in a realistic example. Imagine we’re processing a batch of satellite images of US cities for a segmentation model:

# Simulating a batch of 8 images (each 3x64x64)
batch = torch.rand(8, 3, 64, 64)
print("Original batch shape:", batch.shape)

# Reshaping to combine batch and channels
reshaped_batch = batch.reshape(8, 3 * 64 * 64)
print("Reshaped for FC layer:", reshaped_batch.shape)

# Adding a time dimension for an RNN
time_aware_batch = batch.unsqueeze(1)
print("Time-aware batch:", time_aware_batch.shape)

# Flattening completely
flat_batch = batch.view(-1)
print("Flattened batch:", flat_batch.shape)

# Reshaping back to original
back_to_original = flat_batch.reshape(8, 3, 64, 64)
print("Back to original:", back_to_original.shape)

Output:

Original batch shape: torch.Size([8, 3, 64, 64])
Reshaped for FC layer: torch.Size([8, 12288])
Time-aware batch: torch.Size([8, 1, 3, 64, 64])
Flattened batch: torch.Size([98304])
Back to original: torch.Size([8, 3, 64, 64])

Performance Considerations

Based on my experience, here are some performance tips when reshaping tensors:

  1. view() never copies data, so it guarantees memory efficiency; reshape() also returns a view when it can, but silently falls back to a copy for non-contiguous tensors.
  2. If you’re repeatedly reshaping tensors in a loop, consider doing the reshaping once outside the loop if possible.
  3. For very large tensors (like those in video processing of US traffic camera footage), memory usage matters. Use view() when you can to avoid duplicating data.
  4. When working with non-contiguous tensors, reshape() might be more convenient as it handles the contiguity check internally.
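You can verify the copy-versus-view behavior yourself by comparing storage pointers. This sketch shows reshape() returning a view for a contiguous input and silently copying for a non-contiguous one:

```python
import torch

t = torch.arange(12).reshape(3, 4)

# Contiguous input: reshape() returns a view (same storage)
v = t.reshape(4, 3)
print(v.data_ptr() == t.data_ptr())  # True

# Non-contiguous input: reshape() must copy (new storage)
nc = t.transpose(0, 1)
flat = nc.reshape(-1)
print(flat.data_ptr() == t.data_ptr())  # False
```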

Troubleshoot Common Reshape Errors

Let me show you how to troubleshoot some common reshape errors.

Size Mismatch Error

The most common error is trying to reshape to dimensions that don’t match the total number of elements:

try:
    original_tensor.reshape(5, 5)
except RuntimeError as e:
    print("Error:", e)

Output:

Error: shape '[5, 5]' is invalid for input of size 12

To fix this, ensure the product of your new dimensions equals the total number of elements in the original tensor.
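A simple guard I sometimes use is comparing numel() against the target shape before reshaping (the target shape here is just illustrative):

```python
import math
import torch

t = torch.arange(12)
target = (5, 5)  # hypothetical target shape

# Reshape only when the element counts match, avoiding the RuntimeError
if t.numel() == math.prod(target):
    result = t.reshape(target)
else:
    print(f"Cannot reshape {t.numel()} elements into {target}")
```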

View on Non-Contiguous Tensor

As mentioned earlier, view() requires contiguous tensors:

non_contiguous = original_tensor.reshape(3, 4).transpose(0, 1)
try:
    non_contiguous.view(-1)
except RuntimeError as e:
    print("Error:", e)
    print("Solution: Use contiguous() first or use reshape() instead")

In these cases, use either contiguous().view() or simply reshape().

Summary

Reshaping tensors is a fundamental operation in PyTorch that you’ll use constantly. In this tutorial, I’ve covered:

  • Using reshape() for general-purpose reshaping
  • Using view() for memory-efficient reshaping of contiguous tensors
  • Using unsqueeze() to add singleton dimensions
  • Using squeeze() to remove singleton dimensions

Each method has its place, and knowing when to use which will make your deep learning code more efficient and readable. Next time you’re preprocessing data for a US demographic analysis model or reformatting outputs from an object detection system, you’ll have the right tools to handle any tensor reshaping challenge.

Remember that while reshaping changes the organization of your data, it never changes the data itself or the total number of elements. Mastering these operations will give you greater flexibility in designing and implementing neural networks with PyTorch.
