CutMix Data Augmentation in Keras

I have spent years building deep learning models, and if there is one thing I’ve learned, it is that your model is only as good as your data.

When I first started with Keras, I relied on simple flips and rotations, but my models often struggled with complex, real-world images.

That all changed when I discovered CutMix, a powerful technique that cuts a patch from one image and pastes it into another, forcing the model to learn localized features.

In this tutorial, I will show you exactly how to implement CutMix using the latest Keras API and KerasCV for your image classification projects.

Use the KerasCV CutMix Layer in Python Keras

I find that the easiest way to get started is by using the built-in layers in KerasCV, as they handle the heavy lifting of label mixing automatically.

This method is efficient because the layer operates on whole batches with TensorFlow ops and slots directly into your preprocessing pipeline.

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before Keras is imported

import numpy as np
import tensorflow as tf
import keras
import keras_cv

NUM_CLASSES = 2  # e.g., apples vs. oranges

# Create the layer once and reuse it for every batch
cut_mix_layer = keras_cv.layers.CutMix()

def apply_keras_cv_cutmix(image_batch, label_batch):
    # CutMix requires one-hot encoded, floating-point labels
    labels = tf.one_hot(label_batch, NUM_CLASSES)

    # KerasCV expects a dictionary with "images" and "labels" keys
    samples = {"images": image_batch, "labels": labels}
    output = cut_mix_layer(samples, training=True)

    return output["images"], output["labels"]

# Example usage with dummy data standing in for a batch of 16 fruit images
images = np.random.random((16, 224, 224, 3)).astype("float32")
labels = np.random.randint(0, 2, (16,))
mixed_imgs, mixed_lbls = apply_keras_cv_cutmix(images, labels)
print(f"Mixed Image Shape: {mixed_imgs.shape}")
print(f"Mixed Label Shape: {mixed_lbls.shape}")
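To make the label mixing that the layer automates more concrete, here is a minimal NumPy sketch of mixing within a single batch. It is not KerasCV's implementation: the function name `cutmix_batch`, the single box shared by the whole batch, and the fixed seed are assumptions chosen to keep the example short. Each image is paired with a randomly chosen partner from the same batch, and the labels are blended by the fraction of the image that was replaced.

```python
import numpy as np

def cutmix_batch(images, labels, alpha=1.0, rng=None):
    """Apply one shared CutMix box to a batch, pairing each sample with a shuffled partner.

    images: float array of shape (batch, height, width, channels)
    labels: one-hot float array of shape (batch, num_classes)
    """
    rng = np.random.default_rng() if rng is None else rng
    batch, img_h, img_w = images.shape[:3]

    # Sample the mixing ratio and derive the patch size from it
    lam = rng.beta(alpha, alpha)
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(img_h * cut_ratio), int(img_w * cut_ratio)

    # Pick a random centre and clip the box to the image borders
    cy, cx = rng.integers(img_h), rng.integers(img_w)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, img_h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, img_w)

    # Pair every sample with a partner drawn from the same batch
    partner = rng.permutation(batch)
    mixed = images.copy()
    mixed[:, y1:y2, x1:x2, :] = images[partner, y1:y2, x1:x2, :]

    # Weight the labels by the area that actually survived clipping
    lam_eff = 1.0 - (y2 - y1) * (x2 - x1) / (img_h * img_w)
    mixed_labels = lam_eff * labels + (1.0 - lam_eff) * labels[partner]
    return mixed, mixed_labels

rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3)).astype("float32")
labels = np.eye(10, dtype="float32")[rng.integers(0, 10, size=8)]
mixed_images, mixed_labels = cutmix_batch(images, labels, rng=rng)
print(mixed_images.shape, mixed_labels.shape)
```

Each row of `mixed_labels` still sums to 1, so the result can be treated as a probability vector over the classes.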

Implement Custom CutMix Logic in Python Keras

I often use a custom implementation when I need full control over the bounding box logic or when I am working with older versions of TensorFlow.

This method involves manually sampling from a Beta distribution to decide the patch size, copying that patch from the second image into the first, and mixing the two labels in proportion to the patch area.

import numpy as np

def get_cutmix_sample(image1, image2, label1, label2, alpha=1.0):
    # Determine the mixing ratio from a Beta distribution
    lam = np.random.beta(alpha, alpha)
    
    # Calculate bounding box coordinates for the patch
    img_h, img_w = image1.shape[0], image1.shape[1]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int32(img_w * cut_rat)
    cut_h = np.int32(img_h * cut_rat)

    # Randomly pick the center of the box
    cx = np.random.randint(img_w)
    cy = np.random.randint(img_h)

    # Define the boundaries
    x1 = np.clip(cx - cut_w // 2, 0, img_w)
    y1 = np.clip(cy - cut_h // 2, 0, img_h)
    x2 = np.clip(cx + cut_w // 2, 0, img_w)
    y2 = np.clip(cy + cut_h // 2, 0, img_h)

    # Copy image1 so the caller's array is not modified, then paste the patch from image2
    mixed_image = image1.copy()
    mixed_image[y1:y2, x1:x2, :] = image2[y1:y2, x1:x2, :]
    
    # Recompute lambda from the clipped box, since clipping can shrink the patch
    actual_lam = 1 - ((x2 - x1) * (y2 - y1) / (img_w * img_h))
    mixed_label = actual_lam * label1 + (1 - actual_lam) * label2
    
    return mixed_image, mixed_label

# Let's test it with two sample images
img_a, lbl_a = np.ones((224, 224, 3)), np.array([1, 0])
img_b, lbl_b = np.zeros((224, 224, 3)), np.array([0, 1])
new_img, new_lbl = get_cutmix_sample(img_a, img_b, lbl_a, lbl_b)
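The reason for recomputing lambda from the final box is easiest to see with numbers. The sketch below uses plain Python and two hypothetical helpers, `clipped_box` and `effective_lambda`, that mirror the arithmetic above. The sampled value and the box centres are made up for illustration.

```python
import math

def clipped_box(cx, cy, lam, img_h, img_w):
    """Return (x1, y1, x2, y2) for a patch centred at (cx, cy), clipped to the image."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_w, cut_h = int(img_w * cut_ratio), int(img_h * cut_ratio)
    x1 = max(cx - cut_w // 2, 0)
    y1 = max(cy - cut_h // 2, 0)
    x2 = min(cx + cut_w // 2, img_w)
    y2 = min(cy + cut_h // 2, img_h)
    return x1, y1, x2, y2

def effective_lambda(box, img_h, img_w):
    """Fraction of the image that still belongs to the first sample."""
    x1, y1, x2, y2 = box
    return 1.0 - (x2 - x1) * (y2 - y1) / (img_w * img_h)

lam = 0.75  # sampled value: the patch should cover 25% of the image
centre_box = clipped_box(112, 112, lam, 224, 224)  # box lies fully inside the image
corner_box = clipped_box(0, 0, lam, 224, 224)      # most of the box falls outside

print(effective_lambda(centre_box, 224, 224))  # 0.75
print(effective_lambda(corner_box, 224, 224))  # 0.9375
```

When the box is centred on a corner, only a quarter of it lands on the image, so the first label should carry a weight of 0.9375 rather than the sampled 0.75.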

Integrate CutMix into tf.data Pipelines in Python Keras

I prefer using tf.data for large-scale training because it allows for asynchronous data loading and parallel preprocessing.

In this approach, you zip two instances of your dataset together so that you always have a pair of images ready for the CutMix operation.

def prepare_cutmix_dataset(dataset, batch_size, num_classes):
    # Create the layer once, outside the mapped function
    cutmix = keras_cv.layers.CutMix()

    # Zip two independently shuffled copies of the dataset to get pairs of half-batches
    half = batch_size // 2
    ds_1 = dataset.shuffle(1024).batch(half, drop_remainder=True)
    ds_2 = dataset.shuffle(1024).batch(half, drop_remainder=True)

    train_ds = tf.data.Dataset.zip((ds_1, ds_2))

    def cutmix_fn(batch1, batch2):
        images1, labels1 = batch1
        images2, labels2 = batch2

        # Concatenate the two half-batches into one full batch
        images = tf.concat([images1, images2], axis=0)
        labels = tf.concat([labels1, labels2], axis=0)

        # CutMix requires one-hot encoded, floating-point labels
        labels = tf.one_hot(labels, num_classes)

        # KerasCV CutMix pairs each image with another image from the same batch
        output = cutmix({"images": images, "labels": labels}, training=True)

        # Return an (images, labels) tuple so the dataset can be passed to model.fit
        return output["images"], output["labels"]

    return train_ds.map(cutmix_fn, num_parallel_calls=tf.data.AUTOTUNE)

# Example of how you would call this in your training loop
# train_ds = prepare_cutmix_dataset(raw_dataset, 32, 10)

You can refer to the screenshot below to see the output.

[Screenshot: CutMix data augmentation in Keras]
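Because CutMix produces soft labels rather than class indices, the model should be trained with a loss that accepts probability vectors, such as categorical cross-entropy rather than its sparse variant. As a rough illustration in plain Python, with made-up numbers and a hypothetical helper `cross_entropy`, the loss on a mixed label equals the area-weighted sum of the losses on the two original labels:

```python
import math

def cross_entropy(target, probs):
    """Categorical cross-entropy between a target distribution and predicted probabilities."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)

# Made-up values for illustration only
lam = 0.7                                # fraction of the image kept from the first sample
apple, orange = [1.0, 0.0], [0.0, 1.0]   # one-hot labels of the two source images
mixed = [lam * a + (1 - lam) * o for a, o in zip(apple, orange)]
probs = [0.6, 0.4]                       # hypothetical model prediction

loss_mixed = cross_entropy(mixed, probs)
loss_weighted = lam * cross_entropy(apple, probs) + (1 - lam) * cross_entropy(orange, probs)
print(loss_mixed, loss_weighted)
```

The two printed values agree up to floating-point rounding, which is why a single mixed label is enough to train on both source images at once.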

Applying CutMix data augmentation in Keras is an effective way to improve the generalization of your image classifiers without collecting more raw data.

I hope you found this tutorial helpful and that you can now implement these methods in your own computer vision projects.
