PyTorch Early Stopping: Prevent Overfitting in Your Models

Recently, I was working on a deep learning project where my model was performing great on the training data but poorly on the validation set. The problem was clear: my model was overfitting. This is where early stopping comes to the rescue! In this article, I will show you how to implement early stopping in PyTorch to prevent your models from overfitting and help you build more robust neural networks.

Early stopping is one of the most effective regularization techniques in deep learning. I will cover several implementation approaches, ranging from a simple DIY method to established libraries. Let's dive in!

What Is Early Stopping?

When training neural networks, we often face a common challenge: the model performs increasingly well on training data but starts performing worse on validation data after a certain point. This is overfitting in action.

Early stopping monitors validation performance during training and stops the process when the model stops improving on validation data. It’s like knowing when to stop cooking pasta – if you wait too long, it gets mushy and unappetizing!

Method 1: DIY Early Stopping in PyTorch

The simplest way to implement early stopping is to create your own monitoring mechanism. Here's a straightforward implementation I often use in my projects:

import torch

class EarlyStopping:
    """Stop training when the validation loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=7, min_delta=0, verbose=False):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        # Negate the loss so that a higher score always means "better"
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.min_delta:
            # No meaningful improvement: count this epoch against patience
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: save a checkpoint and reset the counter
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

And here’s how you’d use it in your training loop:

# Initialize early stopping (model, criterion, optimizer, and the
# data loaders are assumed to be defined already)
early_stopping = EarlyStopping(patience=10, verbose=True)

# Training loop
for epoch in range(num_epochs):
    # Training phase
    model.train()
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in valid_loader:
            y_pred = model(X_batch)
            val_loss += criterion(y_pred, y_batch).item()
    val_loss /= len(valid_loader)  # average over validation batches

    # Check early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break
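
One detail that's easy to miss: when the loop exits, the model in memory holds the weights from the last epoch, not the best one. Reload the checkpoint that save_checkpoint wrote to recover the best weights:

# Restore the weights from the epoch with the lowest validation loss
model.load_state_dict(torch.load('checkpoint.pt'))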

This method is flexible and gives you full control over the early stopping logic.
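
That flexibility pays off when you want to monitor something other than the loss. As a hypothetical tweak (not part of the class's intended API), you could track a metric where higher is better, such as accuracy, by negating it before passing it in:

# Hypothetical reuse: negate accuracy so the "lower is better" logic still applies.
# evaluate_accuracy is an assumed helper; the verbose messages will still say "loss".
val_accuracy = evaluate_accuracy(model, valid_loader)
early_stopping(-val_accuracy, model)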

Method 2: Use PyTorch Lightning

If you’re already using PyTorch Lightning (which I highly recommend for structured deep learning projects), early stopping is built right in:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Define the early stopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    verbose=True,
    mode='min'
)

# Use it with the trainer
trainer = Trainer(
    max_epochs=100,
    callbacks=[early_stopping]
)

# Train the model
trainer.fit(model, train_loader, val_loader)
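
One requirement worth spelling out: the monitor key must match a metric that your LightningModule actually logs via self.log. Here's a minimal sketch of what that looks like (LitRegressor and its layer sizes are placeholder choices, not a recommendation):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # This logged 'val_loss' is exactly what the EarlyStopping callback monitors
        self.log('val_loss', F.mse_loss(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)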

PyTorch Lightning’s implementation is robust and integrates well with the rest of the library’s features.

Method 3: Use Ignite

Ignite is another high-level library built on top of PyTorch that provides an early stopping handler:

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import EarlyStopping
from ignite.metrics import Loss

# model, optimizer, criterion, and val_loader are assumed to be defined
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, metrics={'loss': Loss(criterion)})

# EarlyStopping expects "higher is better", so negate the loss
def score_function(engine):
    return -engine.state.metrics['loss']

handler = EarlyStopping(
    patience=10,
    score_function=score_function,
    trainer=trainer
)

# Run validation after every training epoch; the handler fires when it completes
@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(val_loader)

evaluator.add_event_handler(Events.COMPLETED, handler)

Ignite's implementation is event-based and works well within its ecosystem. Once the handler is attached, kick off training with trainer.run(train_loader, max_epochs=100) and early stopping is handled for you.

Fine-tuning Your Early Stopping Strategy

The key parameter in early stopping is patience: the number of epochs to keep training after validation performance stops improving. Tuning it requires some experimentation:

  • Low patience (1-3 epochs): Training might stop too early, before the model reaches optimal performance
  • High patience (15-20 epochs): You might waste computation time, but have a better chance of finding the best model
  • Moderate patience (5-10 epochs): Usually a good compromise for most projects

I typically start with a patience of 10 and adjust based on how training progresses.
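
For the DIY class from Method 1, that starting point looks like this (the min_delta value is just an illustrative threshold so tiny fluctuations don't count as improvements):

# Illustrative starting configuration, to be tuned per project:
# wait 10 epochs, and ignore "improvements" smaller than 1e-4
early_stopping = EarlyStopping(patience=10, min_delta=1e-4, verbose=True)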

Real-world Example: Stock Price Prediction

Let's put early stopping to work in a practical example: an LSTM model that predicts stock prices for major U.S. companies like Apple and Amazon:

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset

# Assume we have stock data loaded
# Load and preprocess data...

# Define a simple LSTM model
class StockPredictor(nn.Module):
    def __init__(self, input_dim=1, hidden_dim=32, num_layers=2, output_dim=1):
        super(StockPredictor, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Use the stored sizes rather than hardcoded values so the model
        # still works if hidden_dim or num_layers are changed
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        # Predict from the final time step's hidden state
        out = self.fc(out[:, -1, :])
        return out

# Create model, criterion, optimizer
model = StockPredictor()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize early stopping
early_stopping = EarlyStopping(patience=10, verbose=True)

# Training loop
for epoch in range(100):
    # Training phase
    model.train()
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in valid_loader:
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            val_loss += loss.item()

    val_loss /= len(valid_loader)
    print(f'Epoch {epoch}: val_loss = {val_loss:.6f}')

    # Check early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch}")
        break

# Load the best model
model.load_state_dict(torch.load('checkpoint.pt'))

In this stock prediction example, early stopping helps prevent the model from memorizing noise in the training data, which is crucial for financial predictions where overfitting can lead to poor real-world performance.

Early stopping is an essential technique in my deep learning toolkit. It not only helps prevent overfitting but also saves computational resources by avoiding unnecessary training iterations.

Remember that early stopping works best when combined with other regularization techniques like dropout and weight decay. In my experience, this combination leads to more robust models that generalize better to unseen data.
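
As a quick illustration of what that combination looks like in code (the layer sizes, dropout rate, and weight decay value are arbitrary placeholders, not recommendations):

import torch.nn as nn
import torch.optim as optim

# Dropout inside the model...
model = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Dropout(p=0.3),          # randomly zeroes activations during training
    nn.Linear(128, 1),
)

# ...weight decay (L2 regularization) in the optimizer...
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# ...and early stopping wrapped around the training loop, as shown earlier
early_stopping = EarlyStopping(patience=10, verbose=True)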
