Recently, I was working on a deep learning project where my model was performing great on the training data but poorly on the validation set. The issue was that my model was overfitting. This is where early stopping comes to the rescue! In this article, I will show you how to implement early stopping in PyTorch to prevent your models from overfitting and help you build more robust neural networks.
Early stopping is one of the most effective regularization techniques in deep learning. I will cover various implementation approaches, ranging from a simple DIY method to established libraries. Let's dive in!
Early Stopping
When training neural networks, we often face a common challenge: the model performs increasingly well on training data but starts performing worse on validation data after a certain point. This is overfitting in action.
Early stopping monitors validation performance during training and stops the process when the model stops improving on validation data. It’s like knowing when to stop cooking pasta – if you wait too long, it gets mushy and unappetizing!
Method 1: DIY Early Stopping in PyTorch
The simplest way to implement early stopping is to create your own monitoring mechanism. Here's an easy implementation I often use in my projects:
import torch

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.min_delta:
            # No meaningful improvement this epoch
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: save the model and reset the counter
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

And here's how you'd use it in your training loop:
# Initialize early stopping
early_stopping = EarlyStopping(patience=10, verbose=True)

# Training loop
for epoch in range(num_epochs):
    # Training phase
    model.train()
    for X_batch, y_batch in train_loader:
        ...  # Training steps

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in valid_loader:
            ...  # Validation steps

    # Check early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break
This method is flexible and gives you full control over the early stopping logic.
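To see the counter logic in isolation, here is a minimal sketch that feeds a synthetic sequence of validation losses into a stripped-down version of the class above (checkpoint saving removed so it runs without a model; the loss values are made up for illustration):

```python
class SimpleEarlyStopping:
    """Same counter logic as the EarlyStopping class above,
    but without model checkpointing."""
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.min_delta:
            self.counter += 1          # no improvement this epoch
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score    # improvement: reset the counter
            self.counter = 0

# Validation loss improves for a few epochs, then plateaus
losses = [0.90, 0.80, 0.75, 0.76, 0.77, 0.78, 0.79]
stopper = SimpleEarlyStopping(patience=3)
for epoch, loss in enumerate(losses):
    stopper(loss)
    if stopper.early_stop:
        print(f"Stopped at epoch {epoch}")  # epoch 5: three epochs without improvement
        break
```

The best loss here is reached at epoch 2; the counter then climbs for three straight epochs and training stops at epoch 5.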
Method 2: Use PyTorch Lightning
If you’re already using PyTorch Lightning (which I highly recommend for structured deep learning projects), early stopping is built right in:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Define the early stopping callback
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    verbose=True,
    mode='min'
)

# Use it with the trainer
trainer = Trainer(
    max_epochs=100,
    callbacks=[early_stopping]
)

# Train the model
trainer.fit(model, train_loader, val_loader)
PyTorch Lightning’s implementation is robust and integrates well with the rest of the library’s features.
Method 3: Use Ignite
Ignite is another high-level library built on top of PyTorch that provides an early stopping handler:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.handlers import EarlyStopping

# Assumes trainer and evaluator were created with
# create_supervised_trainer / create_supervised_evaluator
def score_function(engine):
    # Higher score is better, so negate the loss
    return -engine.state.metrics['loss']

handler = EarlyStopping(
    patience=10,
    score_function=score_function,
    trainer=trainer
)

evaluator.add_event_handler(Events.COMPLETED, handler)

Ignite's implementation is event-based and works well within their ecosystem.
Fine-tuning Your Early Stopping Strategy
The key parameter in early stopping is patience, the number of epochs to wait before stopping after validation performance stops improving. This requires some experimentation:
- Low patience (1-3 epochs): Training might stop too early, before the model reaches optimal performance
- High patience (15-20 epochs): You might waste computation time, but have a better chance of finding the best model
- Moderate patience (5-10 epochs): Usually a good compromise for most projects
I typically start with a patience of 10 and adjust based on how training progresses.
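To get a feel for the trade-off, you can simulate it. The sketch below (pure Python, with a hypothetical validation-loss curve invented for illustration) reports the epoch at which early stopping would trigger for different patience values:

```python
def stop_epoch(losses, patience):
    """Return the epoch at which early stopping would trigger,
    or None if it never does."""
    best = float('inf')
    counter = 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            best = loss      # new best: reset the no-improvement counter
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None

# Synthetic curve: improves quickly, then bounces around a plateau
curve = [1.00, 0.70, 0.55, 0.50, 0.52, 0.49, 0.51, 0.50, 0.52, 0.51,
         0.53, 0.52, 0.54, 0.53, 0.55, 0.54, 0.56, 0.55, 0.57, 0.56]

for patience in (2, 5, 10):
    print(patience, stop_epoch(curve, patience))  # 2 -> 7, 5 -> 10, 10 -> 15
```

On this curve the best loss occurs at epoch 5; a patience of 2 stops at epoch 7, while a patience of 10 burns eight more epochs on the plateau for no gain. On a noisier curve, though, the low-patience run could easily have stopped before the true minimum.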
Real-world Example: Stock Price Prediction
Let’s use early stopping in a practical example, a model predicting stock prices for major U.S. companies like Apple and Amazon:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset

# Assume we have stock data loaded
# Load and preprocess data...

# Define a simple LSTM model
class StockPredictor(nn.Module):
    def __init__(self, input_dim=1, hidden_dim=32, num_layers=2, output_dim=1):
        super(StockPredictor, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Use the stored sizes instead of hard-coding them
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])  # predict from the last time step
        return out

# Create model, criterion, optimizer
model = StockPredictor()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize early stopping
early_stopping = EarlyStopping(patience=10, verbose=True)

# Training loop
for epoch in range(100):
    ...  # Training code

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X_batch, y_batch in valid_loader:
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            val_loss += loss.item()
    val_loss /= len(valid_loader)
    print(f'Epoch {epoch}: val_loss = {val_loss:.6f}')

    # Check early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch}")
        break

# Load the best model
model.load_state_dict(torch.load('checkpoint.pt'))

In this stock prediction example, early stopping helps prevent the model from memorizing noise in the training data, which is crucial for financial predictions where overfitting can lead to poor real-world performance.
Early stopping is an essential technique in my deep learning toolkit. It not only helps prevent overfitting but also saves computational resources by avoiding unnecessary training iterations.
Remember that early stopping works best when combined with other regularization techniques like dropout and weight decay. In my experience, this combination leads to more robust models that generalize better to unseen data.
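As a minimal sketch of that combination (the layer sizes and hyperparameters here are arbitrary, assuming a plain feed-forward model), dropout goes in the model definition and weight decay is a single optimizer argument:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A small network with dropout between layers (sizes are illustrative)
model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Dropout(p=0.3),   # randomly zeroes 30% of activations during training
    nn.Linear(64, 1),
)

# weight_decay adds L2 regularization on top of early stopping
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Dropout is active in train() mode and disabled in eval() mode,
# which is why the validation loop should always call model.eval()
model.train()
y_train = model(torch.randn(8, 20))
model.eval()
y_eval = model(torch.randn(8, 20))
print(y_train.shape, y_eval.shape)
```

Note the interplay: dropout and weight decay constrain the model on every step, while early stopping decides when the whole run should end.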
I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working on Python, machine learning, and artificial intelligence for the last 5 years. During this time I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, Scikit-Learn, etc., for various clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and more. Check out my profile.