If you’ve been working with PyTorch for deep learning projects, you’ve likely encountered situations where you need to combine multiple tensors into a single one. This is where PyTorch’s torch.cat() function becomes invaluable.
In my decade-plus experience with Python and PyTorch, I’ve found that mastering tensor concatenation is essential for efficient data manipulation in neural network workflows.
I remember when I first started building image classification models for a U.S. healthcare project. I struggled with batch processing until I truly understood how to use torch.cat() properly.
Today, I’ll share everything you need to know about this function, from basic usage to advanced techniques that I’ve refined over the years of hands-on development.
What is torch.cat()?
The torch.cat() function in PyTorch is designed to concatenate a sequence of tensors along a specified dimension. Think of it as stacking tensors together like building blocks.
The basic syntax is simple:

```python
torch.cat(tensors, dim=0, *, out=None)
```

Where:

- tensors is the sequence of tensors to concatenate
- dim specifies the dimension along which to concatenate
- out is an optional output tensor
Read PyTorch Dataloader
Basic Usage of torch.cat()
Let’s start with a simple example to understand how torch.cat() works:
```python
import torch

# Create two 2D tensors
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

# Concatenate along dimension 0 (rows)
result_dim0 = torch.cat((tensor1, tensor2), dim=0)
print("Concatenated along dim 0:\n", result_dim0)

# Concatenate along dimension 1 (columns)
result_dim1 = torch.cat((tensor1, tensor2), dim=1)
print("Concatenated along dim 1:\n", result_dim1)
```

Output:

```
Concatenated along dim 0:
 tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
Concatenated along dim 1:
 tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
```
Understanding dimensions is crucial. When dim=0, tensors are stacked vertically (adding rows). When dim=1, they’re stacked horizontally (adding columns).
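To make this concrete for higher-rank tensors, here is a small sketch (the tensor names are my own) showing how each choice of dim changes only the size of that one dimension:

```python
import torch

# Two 3-D tensors with identical shapes
a = torch.zeros(2, 3, 4)
b = torch.zeros(2, 3, 4)

# Only the size of the chosen dimension grows; the rest stay the same
print(torch.cat((a, b), dim=0).shape)  # torch.Size([4, 3, 4])
print(torch.cat((a, b), dim=1).shape)  # torch.Size([2, 6, 4])
print(torch.cat((a, b), dim=2).shape)  # torch.Size([2, 3, 8])

# Negative dims count from the end, so dim=-1 equals dim=2 here
print(torch.cat((a, b), dim=-1).shape)  # torch.Size([2, 3, 8])
```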
Concatenate Multiple Tensors
One of the strengths of torch.cat() is its ability to handle more than two tensors at once:
```python
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])

# Concatenate all three tensors
result = torch.cat((t1, t2, t3), dim=0)
print(result)
```

Output:

```
tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
```
I’ve often used this approach when working with U.S. census data, where I needed to combine multiple demographic feature tensors into a single input for my model.
Check out PyTorch Binary Cross Entropy
Dimension Requirements
An important rule to remember: when concatenating tensors along a specific dimension, all tensors must have the same size in all dimensions except the dimension along which you’re concatenating.
Here’s an example of what works and what doesn’t:
```python
import torch

# These will work
t1 = torch.zeros(2, 3, 4)
t2 = torch.ones(2, 3, 4)
result = torch.cat((t1, t2), dim=0)  # Shape: [4, 3, 4]

# This will fail
t3 = torch.ones(2, 5, 4)
# torch.cat((t1, t3), dim=0)  # Error: sizes of tensors must match except in dimension 0
```

Real-World Example: Batch Processing
Let’s look at a practical example I’ve used when training a CNN on a U.S. traffic sign dataset:
```python
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.view(-1, 16 * 14 * 14)
        x = self.fc(x)
        return x

# Let's say we have batches of images from different sources
batch1 = torch.randn(8, 3, 28, 28)  # 8 images from source 1
batch2 = torch.randn(4, 3, 28, 28)  # 4 images from source 2

# Combine batches
combined_batch = torch.cat((batch1, batch2), dim=0)  # Now we have 12 images
print(f"Combined batch shape: {combined_batch.shape}")

# Process through our model
model = SimpleCNN()
output = model(combined_batch)
print(f"Output shape: {output.shape}")
```

Output:

```
Combined batch shape: torch.Size([12, 3, 28, 28])
Output shape: torch.Size([12, 10])
```
This technique is particularly useful when you have data coming from multiple sources that you want to process together.
Use torch.cat() with Different Tensor Types
When working with different tensor types, PyTorch tries to perform automatic type conversion:
```python
import torch

# Float and integer tensors
float_tensor = torch.tensor([1.1, 2.2, 3.3], dtype=torch.float32)
int_tensor = torch.tensor([4, 5, 6], dtype=torch.int32)

# PyTorch will promote int_tensor to float32
result = torch.cat((float_tensor, int_tensor), dim=0)
print(result)        # tensor([1.1000, 2.2000, 3.3000, 4.0000, 5.0000, 6.0000])
print(result.dtype)  # torch.float32
```

However, I recommend explicitly converting tensors to the same type for clarity and to avoid unexpected behavior.
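One way to make that conversion explicit (a small sketch with my own variable names) is to cast every tensor to a common dtype before concatenating, so the result dtype is never a surprise:

```python
import torch

float_tensor = torch.tensor([1.1, 2.2, 3.3], dtype=torch.float32)
int_tensor = torch.tensor([4, 5, 6], dtype=torch.int32)

# Cast explicitly to a shared dtype before concatenating
tensors = [t.to(torch.float32) for t in (float_tensor, int_tensor)]
result = torch.cat(tensors, dim=0)
print(result.dtype)  # torch.float32
```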
GPU Considerations
When working with PyTorch on GPU, all tensors being concatenated need to be on the same device:
```python
import torch

if torch.cuda.is_available():
    cuda_tensor1 = torch.tensor([1, 2, 3], device='cuda')
    cuda_tensor2 = torch.tensor([4, 5, 6], device='cuda')

    # This works fine
    result = torch.cat((cuda_tensor1, cuda_tensor2), dim=0)

    # This would fail
    cpu_tensor = torch.tensor([7, 8, 9], device='cpu')
    # torch.cat((cuda_tensor1, cpu_tensor), dim=0)  # Error: expected all tensors to be on the same device
```

I learned this the hard way when deploying a model for a U.S. retailer's recommendation system that needed to handle both CPU and GPU tensors.
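A defensive pattern I find useful in that situation (a sketch, not the exact code from that project; the helper name is my own) is to move every tensor to one target device before calling torch.cat():

```python
import torch

def cat_on_device(tensors, dim=0, device=None):
    """Move all tensors to one device, then concatenate.

    If device is None, use the device of the first tensor.
    """
    if device is None:
        device = tensors[0].device
    return torch.cat([t.to(device) for t in tensors], dim=dim)

# Works regardless of where each input tensor currently lives
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
print(cat_on_device((a, b)))  # tensor([1, 2, 3, 4, 5, 6])
```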
Check out PyTorch MNIST Tutorial
torch.cat() vs. torch.stack()
It’s worth distinguishing between torch.cat() and another common function, torch.stack():
```python
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])

# Concatenation joins existing dimensions
cat_result = torch.cat((t1, t2), dim=0)
print(f"cat result shape: {cat_result.shape}")  # torch.Size([6])

# Stacking adds a new dimension
stack_result = torch.stack((t1, t2), dim=0)
print(f"stack result shape: {stack_result.shape}")  # torch.Size([2, 3])
```

I use torch.cat() when I want to merge existing dimensions, and torch.stack() when I want to create a new dimension.
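A handy way to remember the difference: torch.stack() is equivalent to unsqueezing each tensor to add the new dimension and then concatenating along it:

```python
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])

# stack == unsqueeze each tensor, then cat along the new dimension
stacked = torch.stack((t1, t2), dim=0)
manual = torch.cat((t1.unsqueeze(0), t2.unsqueeze(0)), dim=0)
print(torch.equal(stacked, manual))  # True
```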
Performance Tips
After years of optimizing PyTorch code, I’ve found these performance tips for torch.cat():
- Pre-allocate when possible: If you’re repeatedly concatenating in a loop, pre-allocate a tensor and use the out parameter:

```python
import torch

result = torch.zeros(10, 3, 4)
for i in range(5):
    t1 = torch.ones(1, 3, 4) * i
    t2 = torch.ones(1, 3, 4) * (i + 5)
    torch.cat((t1, t2), dim=0, out=result[i*2:(i+1)*2])
```

- Batch operations: It’s usually more efficient to concatenate many tensors at once rather than repeatedly concatenating pairs.
- Consider using lists first: For dynamically growing collections, append to a list and perform a single concatenation at the end:

```python
import torch

tensor_list = []
for i in range(10):
    tensor_list.append(torch.ones(1, 5) * i)

# Single concatenation at the end
result = torch.cat(tensor_list, dim=0)
```

Check out PyTorch Fully Connected Layer
Common Errors and Solutions
In my experience, these are the most common issues with torch.cat():
- Dimension mismatch: Ensure all tensors have matching dimensions except the one you’re concatenating along.
- Empty tensor handling: Be careful when concatenating empty tensors. An empty torch.tensor([]) defaults to float32, so mixing it with an integer tensor either raises a dtype error (in older PyTorch versions) or silently promotes the result to float:

```python
import torch

empty = torch.tensor([])            # dtype is float32 by default
nonempty = torch.tensor([1, 2, 3])  # dtype is int64

# Depending on the PyTorch version, this either raises a dtype
# error or promotes the result to float32
# torch.cat((empty, nonempty), dim=0)

# Safer way to handle possibly-empty tensors
if empty.numel() == 0:
    result = nonempty
else:
    result = torch.cat((empty, nonempty), dim=0)
```

- Device mismatch: Ensure all tensors are on the same device (CPU or GPU).
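When chasing down a dimension mismatch, it can help to catch the RuntimeError and print the offending shapes; this is a sketch of the kind of check I mean, reusing the mismatched shapes from earlier:

```python
import torch

t1 = torch.zeros(2, 3, 4)
t3 = torch.ones(2, 5, 4)

try:
    torch.cat((t1, t3), dim=0)  # sizes differ in dim 1, so this raises
except RuntimeError:
    print("cat failed, shapes were:", [t.shape for t in (t1, t3)])

# Concatenating along dim 1 works, because all the other dims match
print(torch.cat((t1, t3), dim=1).shape)  # torch.Size([2, 8, 4])
```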
Over the years, handling these errors properly has saved me countless debugging hours.
When working with PyTorch’s torch.cat() function, I’ve found it to be an essential tool for tensor manipulation in deep learning workflows. From basic tensor joining to complex model architecture implementations, understanding this function thoroughly can significantly improve your PyTorch code efficiency.
Remember that the dimension parameter is crucial; it determines how your tensors will be joined together. And always ensure your tensors have compatible shapes except along the concatenation dimension.
I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last five years. During this time, I have gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, etc., for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and other countries. Check out my profile.