Cross Entropy Loss in PyTorch

Recently, I was working on a deep learning project where I needed to train a neural network for image classification. When selecting a loss function, I found that Cross Entropy Loss was recommended for multi-class classification problems, but I wasn’t entirely clear on how to implement it properly in PyTorch.

Cross-entropy loss is a widely used loss function in classification tasks, particularly for neural networks. In this article, I will explain what cross-entropy loss is, why it is important, and demonstrate various methods to implement it in PyTorch.

So, let's get started!

Cross-Entropy Loss

Cross-Entropy Loss measures the performance of a classification model whose output is a probability value between 0 and 1. It increases as the predicted probability diverges from the actual label, penalizing confident incorrect predictions more heavily.

Simply put, it tells us how different our predicted probability distribution is from the true distribution of labels. When our model makes confident but wrong predictions, the loss is higher.
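To make that concrete, here is a tiny sketch (the function name is my own) computing the single-example loss, -ln(p), where p is the probability the model assigned to the true class:

```python
import math

def cross_entropy_single(p_true_class):
    """Cross-entropy for one example: -ln(p), where p is the
    probability assigned to the correct class."""
    return -math.log(p_true_class)

print(cross_entropy_single(0.9))  # confident and correct -> ~0.105
print(cross_entropy_single(0.1))  # confident and wrong   -> ~2.303
```

Notice how the wrong, confident prediction is penalized far more heavily than the correct one.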

Implement Cross-Entropy Loss in PyTorch

Now, let's look at several ways to implement Cross-Entropy Loss in PyTorch.

Method 1: Use nn.CrossEntropyLoss

The easiest way to implement Cross Entropy Loss in PyTorch is by using the built-in nn.CrossEntropyLoss class. This is what I use in most of my projects because it’s efficient and handles many common cases automatically.

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

# Create model, loss function, and optimizer
input_size = 784  # For MNIST dataset (28x28 pixels)
num_classes = 10  # 10 digits (0-9)
model = SimpleModel(input_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Example inputs and targets
inputs = torch.randn(100, input_size)  # Batch of 100 examples
targets = torch.randint(0, num_classes, (100,))  # Class labels (0-9)

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)

# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item()}")

Output:

Loss: 2.493196487426758

Important things to note about nn.CrossEntropyLoss:

  • It combines nn.LogSoftmax and nn.NLLLoss in one single class
  • It expects raw logits as input (not softmax outputs)
  • Target should be class indices (not one-hot encoded)
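The first and third points are easy to verify yourself. The sketch below shows that nn.CrossEntropyLoss applied to raw logits gives the same result as nn.LogSoftmax followed by nn.NLLLoss, with targets given as plain class indices:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)            # raw logits: 4 examples, 3 classes
targets = torch.tensor([0, 2, 1, 2])  # class indices, not one-hot vectors

ce = nn.CrossEntropyLoss()(logits, targets)

# Equivalent two-step version: LogSoftmax followed by NLLLoss
log_probs = nn.LogSoftmax(dim=1)(logits)
nll = nn.NLLLoss()(log_probs, targets)

print(ce.item(), nll.item())  # identical values
```

If you accidentally apply softmax before passing the outputs in, the loss will still compute but will be silently wrong, so this is worth double-checking in your own code.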

Method 2: Use F.cross_entropy (Functional API)

If you prefer a more functional approach, PyTorch also provides the F.cross_entropy function, which works in the same way but doesn’t require creating a loss object.

import torch
import torch.nn.functional as F

# Same model as before
# ...

# Forward pass
outputs = model(inputs)
loss = F.cross_entropy(outputs, targets)

# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}") 

Output:

Loss: 2.4561

I sometimes use this method when I need a more flexible approach or when I’m experimenting with different loss functions and don’t want to create multiple loss objects.
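Part of what makes the functional form convenient for experiments is that options are passed per call rather than baked into a loss object. For example (label_smoothing requires a reasonably recent PyTorch, 1.10 or later):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)
targets = torch.randint(0, 10, (5,))

# Per-example losses instead of a single averaged scalar
per_example = F.cross_entropy(logits, targets, reduction='none')
print(per_example.shape)  # torch.Size([5])

# Label smoothing, switched on for this call only
smoothed = F.cross_entropy(logits, targets, label_smoothing=0.1)
print(smoothed.item())
```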

Method 3: Manual Implementation

For educational purposes or when you need custom behavior, you can implement Cross Entropy Loss manually:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.linear(x)

# Manual cross-entropy loss function
def manual_cross_entropy(outputs, targets, reduction='mean'):
    # Apply log softmax
    log_softmax = F.log_softmax(outputs, dim=1)
    
    # Gather the log probabilities corresponding to the correct classes
    loss = -log_softmax.gather(1, targets.unsqueeze(1)).squeeze(1)  # shape: (batch,)

    # Apply reduction
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:  # 'none'
        return loss

# Hyperparameters
input_size = 784
num_classes = 10

# Initialize model and optimizer
model = SimpleModel(input_size, num_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Example inputs and targets
inputs = torch.randn(100, input_size)  # batch of 100 examples
targets = torch.randint(0, num_classes, (100,))  # class labels

# Forward pass
outputs = model(inputs)

# Use the manual loss function
loss = manual_cross_entropy(outputs, targets)

# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Print loss
print(f"Manual Cross Entropy Loss: {loss.item():.4f}")

Output:

Manual Cross Entropy Loss: 2.6011

I rarely need to implement it manually, but understanding the underlying mathematics helps me debug issues and customize the loss function when needed.

Handle Imbalanced Classes with Weighted Cross-Entropy

When working with imbalanced datasets (where some classes have many more examples than others), I often use a weighted version of Cross Entropy Loss:

# Define class weights (inversely proportional to class frequencies)
class_counts = [100, 200, 50, 300, 150]  # Example class frequencies
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
class_weights = class_weights / class_weights.sum() * len(class_counts)

# Create weighted loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)

This gives more importance to underrepresented classes, helping the model learn from all classes equally.
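To see what the weight argument actually does, here is a small sketch (the tensors are made-up examples): with the default mean reduction, the weighted loss is the weighted average of the per-example losses, using the weight of each example's true class:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 5
logits = torch.randn(8, num_classes)
targets = torch.randint(0, num_classes, (8,))

class_counts = [100, 200, 50, 300, 150]
weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
weights = weights / weights.sum() * num_classes

plain = nn.CrossEntropyLoss()(logits, targets)
weighted = nn.CrossEntropyLoss(weight=weights)(logits, targets)

# The weighted loss is the weighted average of the per-example losses
per_example = nn.CrossEntropyLoss(reduction='none')(logits, targets)
manual = (weights[targets] * per_example).sum() / weights[targets].sum()

print(weighted.item(), manual.item())  # identical
```

So an example from a rare class (here class 2, with only 50 samples) contributes more to the average than one from a common class.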

Multi-Label Classification with BCE Loss

Sometimes I need to handle multi-label classification (where each example can belong to multiple classes). In such cases, I use Binary Cross-Entropy Loss:

# For multi-label classification
criterion = nn.BCEWithLogitsLoss()

# Multi-hot encoded targets (each example can have multiple labels)
targets = torch.zeros(100, num_classes)
for i in range(100):
    num_labels = torch.randint(1, 4, (1,)).item()  # 1-3 labels per example
    random_labels = torch.randperm(num_classes)[:num_labels]
    targets[i, random_labels] = 1

# Forward pass with multi-label targets (reusing the model and inputs from earlier)
outputs = model(inputs)
loss = criterion(outputs, targets)

The BCEWithLogitsLoss combines a sigmoid activation with binary cross-entropy loss, making it numerically stable.
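You can confirm that equivalence numerically. The sketch below compares BCEWithLogitsLoss on raw logits with BCELoss applied to sigmoid outputs; for moderate logits the values match, while the combined version stays stable for extreme logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)                      # raw scores, one per label
targets = torch.randint(0, 2, (4, 3)).float()   # multi-hot labels

combined = nn.BCEWithLogitsLoss()(logits, targets)

# Two-step equivalent (less numerically stable for extreme logits)
two_step = nn.BCELoss()(torch.sigmoid(logits), targets)

print(combined.item(), two_step.item())  # same value here
```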

Practical Example: Train an MNIST Classifier

Let me show you a complete example of training a simple classifier on the MNIST dataset using Cross Entropy Loss:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

# Define the model
class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)  # plain Dropout: fc1 output is 2D, not a feature map
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

model = MNISTModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train(model, train_loader, optimizer, criterion, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if batch_idx % 100 == 99:
                print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {running_loss/100:.4f}')
                running_loss = 0.0

# Test the model
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    print(f'Accuracy: {100 * correct / total:.2f}%')

# Run training and testing
train(model, train_loader, optimizer, criterion)
test(model, test_loader)

In this example, we’re using Cross Entropy Loss to train a convolutional neural network on the MNIST dataset. The model learns to classify handwritten digits from 0-9.

Understand Loss Values

When using Cross Entropy Loss, I’ve found that the scale of the loss values can sometimes be confusing. Here’s what I’ve learned:

  • For a perfectly predicted example, the loss approaches 0
  • For a random initialization, the loss is approximately ln(num_classes), so about 2.3 for a 10-class problem
  • A loss well below the random baseline (for example, 0.3 to 0.5 on a task like MNIST) usually indicates the model is learning well
  • If your loss doesn’t decrease over time, you likely have an issue with your model or learning rate
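The second point is easy to sanity-check: with all-equal logits, the softmax is uniform, so the loss is exactly ln(num_classes):

```python
import math
import torch
import torch.nn as nn

num_classes = 10
# All-zero logits mimic a model that hasn't learned anything:
# the softmax assigns probability 1/10 to every class
logits = torch.zeros(32, num_classes)
targets = torch.randint(0, num_classes, (32,))

loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item(), math.log(num_classes))  # both ≈ 2.3026
```

This is a handy first-batch sanity check: if your initial loss is far from ln(num_classes), something is likely wrong with your inputs or labels.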

I hope you found this article helpful. Cross-entropy loss is a powerful tool in the deep learning toolbox, and understanding how to use it effectively in PyTorch can greatly improve your model’s performance.
