Recently, I was working on a deep learning project where I needed to train a neural network for image classification. When selecting a loss function, I found that Cross Entropy Loss was recommended for multi-class classification problems, but I wasn’t entirely clear on how to implement it properly in PyTorch.
Cross-entropy loss is a widely used loss function in classification tasks, particularly for neural networks. In this article, I will explain what cross-entropy loss is, why it is important, and demonstrate various methods to implement it in PyTorch.
So let's get started!
Cross-Entropy Loss
Cross-Entropy Loss measures the performance of a classification model whose output is a probability value between 0 and 1. It increases as the predicted probability diverges from the actual label, penalizing confident incorrect predictions more heavily.
Simply put, it tells us how different our predicted probability distribution is from the true distribution of labels. When our model makes confident but wrong predictions, the loss is higher.
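To make this concrete, here is a tiny plain-Python illustration (the probabilities are made up): the loss for one example is just the negative log of the probability the model assigns to the true class.

import math

# Tiny illustration (made-up probabilities): cross-entropy for one example is
# the negative log of the probability assigned to the true class.
def cross_entropy_single(predicted_probs, true_class):
    return -math.log(predicted_probs[true_class])

confident_and_right = [0.9, 0.05, 0.05]   # true class is index 0
confident_but_wrong = [0.05, 0.9, 0.05]   # true class is still index 0, but the model bets on class 1

print(cross_entropy_single(confident_and_right, 0))  # ~0.105 (small loss)
print(cross_entropy_single(confident_but_wrong, 0))  # ~3.0 (much larger loss)

The confident but wrong prediction produces a loss roughly 30 times larger, which is exactly the penalizing behavior described above.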
Implement Cross-Entropy Loss in PyTorch
Now, I will explain the implementation methods of Cross-Entropy Loss in PyTorch.
Method 1: Use nn.CrossEntropyLoss
The easiest way to implement Cross Entropy Loss in PyTorch is by using the built-in nn.CrossEntropyLoss class. This is what I use in most of my projects because it’s efficient and handles many common cases automatically.
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

# Create model, loss function, and optimizer
input_size = 784   # For MNIST dataset (28x28 pixels)
num_classes = 10   # 10 digits (0-9)
model = SimpleModel(input_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Example inputs and targets
inputs = torch.randn(100, input_size)            # Batch of 100 examples
targets = torch.randint(0, num_classes, (100,))  # Class labels (0-9)

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)

# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item()}")

Output:

Loss: 2.493196487426758
Important things to note about nn.CrossEntropyLoss:
- It combines nn.LogSoftmax and nn.NLLLoss in one single class
- It expects raw logits as input (not softmax outputs)
- Targets should be class indices (not one-hot encoded); see the short example after this list
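To make the last two points concrete, here is a small standalone sketch (the shapes and values are purely illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()

logits = torch.randn(4, 10)                 # raw scores; do NOT apply softmax first
class_indices = torch.tensor([3, 7, 0, 9])  # targets as class indices, not one-hot
loss = criterion(logits, class_indices)     # correct usage

# If your targets happen to be one-hot encoded, convert them to indices first:
one_hot = F.one_hot(class_indices, num_classes=10).float()
indices_again = one_hot.argmax(dim=1)       # back to class indices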
Method 2: Use F.cross_entropy (Functional API)
If you prefer a more functional approach, PyTorch also provides the F.cross_entropy function, which works in the same way but doesn’t require creating a loss object.
import torch
import torch.nn.functional as F

# Same model as before
# ...

# Forward pass
outputs = model(inputs)
loss = F.cross_entropy(outputs, targets)

# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")

Output:

Loss: 2.4561
I sometimes use this method when I need a more flexible approach or when I’m experimenting with different loss functions and don’t want to create multiple loss objects.
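As one example of that flexibility, the functional form makes it easy to get per-example losses by passing reduction='none'. This is a small sketch that reuses the outputs and targets variables from the snippet above:

import torch
import torch.nn.functional as F

# One loss value per example instead of a single averaged scalar
per_example_loss = F.cross_entropy(outputs, targets, reduction='none')
print(per_example_loss.shape)   # torch.Size([100])

# Inspect the hardest examples in the batch before averaging or re-weighting
hardest = per_example_loss.topk(5).indices
print(f"Hardest examples in this batch: {hardest.tolist()}")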
Method 3: Manual Implementation
For educational purposes or when you need custom behavior, you can implement Cross Entropy Loss manually:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

# Manual cross-entropy loss function
def manual_cross_entropy(outputs, targets, reduction='mean'):
    # Apply log softmax
    log_softmax = F.log_softmax(outputs, dim=1)
    # Gather the log probabilities corresponding to the correct classes
    loss = -log_softmax.gather(1, targets.unsqueeze(1)).squeeze()
    # Apply reduction
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:  # 'none'
        return loss

# Hyperparameters
input_size = 784
num_classes = 10

# Initialize model and optimizer
model = SimpleModel(input_size, num_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Example inputs and targets
inputs = torch.randn(100, input_size)            # batch of 100 examples
targets = torch.randint(0, num_classes, (100,))  # class labels

# Forward pass
outputs = model(inputs)

# Use the manual loss function
loss = manual_cross_entropy(outputs, targets)

# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print loss
print(f"Manual Cross Entropy Loss: {loss.item():.4f}")

Output:

Manual Cross Entropy Loss: 2.6011
I rarely need to implement it manually, but understanding the underlying mathematics helps me debug issues and customize the loss function when needed.
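For example, one customization you could build on top of the manual version is label smoothing, sketched below. This is my own illustration rather than part of the original examples; note that recent PyTorch versions also expose this directly through the label_smoothing argument of nn.CrossEntropyLoss.

import torch.nn.functional as F

def smoothed_cross_entropy(outputs, targets, smoothing=0.1):
    # Label smoothing: the true class keeps (1 - smoothing) of the target mass,
    # the remaining mass is spread uniformly over all classes.
    log_probs = F.log_softmax(outputs, dim=1)
    nll = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    uniform = -log_probs.mean(dim=1)
    loss = (1.0 - smoothing) * nll + smoothing * uniform
    return loss.mean()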
Handle Imbalanced Classes with Weighted Cross-Entropy
When working with imbalanced datasets (where some classes have many more examples than others), I often use a weighted version of Cross Entropy Loss:
# Define class weights (inversely proportional to class frequencies)
class_counts = [100, 200, 50, 300, 150] # Example class frequencies
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
class_weights = class_weights / class_weights.sum() * len(class_counts)
# Create weighted loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)

This gives more importance to underrepresented classes, helping the model learn from all classes equally.
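As a quick sanity check (purely illustrative, with random logits), you can compare the weighted criterion defined above against an unweighted one on the same batch:

# Random batch over the 5 example classes defined above
logits = torch.randn(8, len(class_counts))
labels = torch.randint(0, len(class_counts), (8,))

weighted_loss = criterion(logits, labels)          # rare classes count for more
unweighted_loss = nn.CrossEntropyLoss()(logits, labels)
print(f"Weighted: {weighted_loss.item():.4f}, Unweighted: {unweighted_loss.item():.4f}")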
Multi-Label Classification with BCE Loss
Sometimes I need to handle multi-label classification (where each example can belong to multiple classes). In such cases, I use Binary Cross-Entropy Loss:
# For multi-label classification
criterion = nn.BCEWithLogitsLoss()
# Multi-hot encoded targets (each example can have multiple labels)
targets = torch.zeros(100, num_classes)
for i in range(100):
    num_labels = torch.randint(1, 4, (1,)).item()   # 1-3 labels per example
    random_labels = torch.randperm(num_classes)[:num_labels]
    targets[i, random_labels] = 1

# Forward pass with multi-label targets
outputs = model(inputs)
loss = criterion(outputs, targets)

The BCEWithLogitsLoss combines a sigmoid activation with binary cross-entropy loss, making it numerically stable.
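Because the loss works on raw logits, you apply the sigmoid yourself at inference time. A common (but not mandatory) recipe is to threshold the resulting probabilities at 0.5:

# Multi-label predictions at inference time
probabilities = torch.sigmoid(outputs)
predictions = (probabilities > 0.5).float()   # one 0/1 flag per class per example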
Practical Example: Train an MNIST Classifier
Let me show you a complete example of training a simple classifier on the MNIST dataset using Cross Entropy Loss:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

# Define the model
class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

model = MNISTModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train(model, train_loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 100 == 99:
                print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {running_loss/100:.4f}')
                running_loss = 0.0

# Test the model
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print(f'Accuracy: {100 * correct / total:.2f}%')

# Run training and testing
train(model, train_loader, optimizer, criterion)
test(model, test_loader)

In this example, we're using Cross Entropy Loss to train a convolutional neural network on the MNIST dataset. The model learns to classify handwritten digits from 0-9.
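One small note on inference: the model outputs raw logits, because nn.CrossEntropyLoss applies log-softmax internally during training. If you want actual probabilities at prediction time, apply softmax yourself. Here is a minimal sketch reusing the trained model and test_loader from above:

# Probabilities for a single test image (the model was set to eval() inside test())
with torch.no_grad():
    sample_images, _ = next(iter(test_loader))
    sample_logits = model(sample_images[:1])
    sample_probs = torch.softmax(sample_logits, dim=1)
    print(sample_probs.sum().item())          # ~1.0, a proper probability distribution
    print(sample_probs.argmax(dim=1).item())  # predicted digit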
Understand Loss Values
When using Cross Entropy Loss, I’ve found that the scale of the loss values can sometimes be confusing. Here’s what I’ve learned:
- For a perfectly predicted example, the loss approaches 0
- For a random initialization, the loss is approximately ln(num_classes), so about 2.3 for a 10-class problem (see the quick check after this list)
- A loss value of 0.3-0.5 typically indicates good learning
- If your loss doesn’t decrease over time, you likely have an issue with your model or learning rate
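To sanity-check the second point, a uniform set of logits over 10 classes reproduces the ln(10) ≈ 2.30 figure:

import math
import torch
import torch.nn as nn

# All-zero logits mean every class gets probability 0.1 after softmax,
# so the loss is -log(0.1) = ln(10).
uniform_logits = torch.zeros(1, 10)
target = torch.tensor([0])
loss = nn.CrossEntropyLoss()(uniform_logits, target)
print(loss.item(), math.log(10))   # both ~2.3026

This also explains why the losses in the earlier snippets (around 2.5-2.6) are close to 2.3: those models were untrained and essentially guessing.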
I hope you found this article helpful. Cross-entropy loss is a powerful tool in the deep learning toolbox, and understanding how to use it effectively in PyTorch can greatly improve your model’s performance.