Implement Masked Image Modeling with Keras Autoencoders

I’ve often found that the best way for a model to “learn” an image is by trying to fix a broken one.

Masked Image Modeling (MIM) is a fascinating technique where we intentionally hide parts of an image and ask an Autoencoder to reconstruct the missing pieces.

I remember the first time I applied this to a dataset of satellite imagery for a logistics firm in Chicago; the results in feature extraction were significantly better than standard supervised learning.

In this tutorial, I will show you how to build a robust Masked Image Modeling pipeline using Python and Keras.

Set Up the Keras Environment for Image Modeling

Before we get into the architecture, we need to ensure our environment is ready with the necessary deep learning libraries.

I usually stick with the latest TensorFlow/Keras releases because they offer the best support for custom layers and preprocessing.

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Load the Image Dataset Using Keras

For this example, we will use the CIFAR-10 dataset, which is a staple for testing reconstruction tasks in the Python community.

I prefer this dataset for tutorials because it’s lightweight enough to run on a standard laptop while being complex enough to show real results.

# Loading the CIFAR-10 dataset for our modeling task
(x_train, _), (x_test, _) = keras.datasets.cifar10.load_data()

# Normalizing the pixel values to be between 0 and 1
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

Method 1: Create a Random Square Masking Function

The most common way I implement MIM is by creating a function that drops a random square “patch” from the image.

def apply_square_mask(images, mask_size=8):
    # Create a copy so original data is not modified
    masked_images = images.copy()
    num_images, height, width, channels = images.shape
    
    for i in range(num_images):
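        # Random top-left corner for the mask; randint excludes its upper
        # bound, so the +1 lets the square reach the bottom and right edges
        y = np.random.randint(0, height - mask_size + 1)
        x = np.random.randint(0, width - mask_size + 1)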

        # Apply black square mask
        masked_images[i, y:y+mask_size, x:x+mask_size, :] = 0
        
    return masked_images

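# Mask the training set (model inputs) and the test set (validation inputs)
x_train_masked = apply_square_mask(x_train)
x_test_masked = apply_square_mask(x_test)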

# Simple console output
print("Masked training set shape:", x_train_masked.shape)

# Count pixels in the first image that are black in all three channels
# (the 8 x 8 = 64 masked pixels, plus any pixels that were already black)
num_masked_pixels = np.sum(np.all(x_train_masked[0] == 0, axis=-1))
print("Number of masked pixels in first image:", num_masked_pixels)


Hiding a patch forces the Autoencoder to use the spatial context of the surrounding pixels to infer what belongs in the hole, rather than simply copying its input.
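apply_square_mask only returns the masked images, so afterwards there is no direct way to tell a hidden pixel from one that was black to begin with. A minimal sketch, assuming you also want the mask itself (for example, to measure error only on hidden pixels): the hypothetical square_mask_with_indicator below builds the same kind of random square with NumPy broadcasting instead of a Python loop, and returns a boolean (N, H, W) indicator alongside the masked copy.

```python
import numpy as np

def square_mask_with_indicator(images, mask_size=8, rng=None):
    """Hypothetical helper: hide one random square per image.

    Returns the masked copy and a boolean (N, H, W) array that is True
    exactly where pixels were hidden.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, h, w, _ = images.shape
    # Top-left corners; integers() excludes the upper bound, hence the +1
    ys = rng.integers(0, h - mask_size + 1, size=n)
    xs = rng.integers(0, w - mask_size + 1, size=n)
    rows = np.arange(h)[None, :]  # shape (1, H)
    cols = np.arange(w)[None, :]  # shape (1, W)
    in_rows = (rows >= ys[:, None]) & (rows < ys[:, None] + mask_size)  # (N, H)
    in_cols = (cols >= xs[:, None]) & (cols < xs[:, None] + mask_size)  # (N, W)
    mask = in_rows[:, :, None] & in_cols[:, None, :]  # (N, H, W)
    # Broadcast the mask over the channel axis and zero the hidden pixels
    masked = np.where(mask[..., None], 0, images).astype(images.dtype)
    return masked, mask

rng = np.random.default_rng(0)
demo = rng.random((4, 32, 32, 3)).astype("float32")
masked, mask = square_mask_with_indicator(demo, rng=rng)
print(mask.sum(axis=(1, 2)))  # [64 64 64 64]
```

Because the square always fits inside the image, every image gets exactly 8 x 8 = 64 hidden pixels, and the indicator counts them correctly even where the original image was black.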

Method 2: Design the Keras Autoencoder Architecture

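I’ve designed this Autoencoder with a symmetrical “Bottleneck” structure: the encoder halves the spatial resolution twice with max pooling (32×32 to 16×16 to 8×8), and the decoder doubles it back to 32×32 with stride-2 transposed convolutions.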

def build_keras_autoencoder(input_shape=(32, 32, 3)):
    # Encoder
    inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    encoded = layers.MaxPooling2D((2, 2), padding='same')(x)

    # Decoder
    x = layers.Conv2DTranspose(64, (3, 3), strides=2, activation='relu', padding='same')(encoded)
    x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation='relu', padding='same')(x)
    decoded = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)

    return keras.Model(inputs, decoded)

# Build the autoencoder
autoencoder = build_keras_autoencoder()

# Compile the model
autoencoder.compile(optimizer='adam', loss='mse')

print("Autoencoder model built successfully!")
print("Input shape:", autoencoder.input_shape)
print("Output shape:", autoencoder.output_shape)

# Print model summary
autoencoder.summary() 


The encoder captures the “essence” of the image, while the decoder uses that information to predict what the masked pixels should have been.
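The shape arithmetic behind this encoder/decoder pair can be checked without TensorFlow. The sketch below hard-codes the layer list from build_keras_autoencoder and applies the 'same'-padding rules: stride-1 convolutions keep height and width, 2×2 pooling halves them (rounding up), and stride-2 transposed convolutions double them. The helper trace_shapes is hypothetical and only illustrative; autoencoder.summary() remains the authoritative output.

```python
import math

def trace_shapes(size=32, channels=3):
    """Trace (height, width, channels) through build_keras_autoencoder's layers."""
    shapes = [("Input", (size, size, channels))]
    # Encoder: a stride-1 'same' Conv2D keeps H and W;
    # MaxPooling2D((2, 2), padding='same') gives ceil(n / 2)
    for filters in (32, 64):
        shapes.append((f"Conv2D({filters})", (size, size, filters)))
        size = math.ceil(size / 2)
        shapes.append(("MaxPooling2D", (size, size, filters)))
    # Decoder: Conv2DTranspose(strides=2, padding='same') gives n * 2
    for filters in (64, 32):
        size *= 2
        shapes.append((f"Conv2DTranspose({filters})", (size, size, filters)))
    shapes.append(("Conv2D(3)", (size, size, 3)))
    return shapes

for name, shape in trace_shapes():
    print(f"{name:<22} {shape}")

bottleneck = trace_shapes()[4][1]  # output of the second MaxPooling2D
print("bottleneck values:", math.prod(bottleneck), "| input values:", 32 * 32 * 3)
# bottleneck values: 4096 | input values: 3072
```

The bottleneck is smaller spatially (8×8) but holds more numbers than the input, so the compression here is spatial only. What stops the network from learning a plain copy is the masking: the hidden pixels are simply not present in its input.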

Method 3: Train the Model for Image Reconstruction

During training, we provide the “masked” images as the input and the “original” clean images as the target labels.

# Fitting the model using the masked images as inputs and original as targets
autoencoder.fit(
    x_train_masked, x_train,
    epochs=20,
    batch_size=128,
    shuffle=True,
    validation_data=(x_test_masked, x_test)
)
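x_train_masked is computed once, so every epoch shows the model the same hole in each image. A common alternative is to draw fresh masks for every batch. The sketch below is a hypothetical NumPy generator, masked_batches, that samples a batch and applies a new random square each time; it is not part of the original pipeline.

```python
import numpy as np

def masked_batches(images, batch_size=128, mask_size=8, seed=None):
    """Hypothetical generator: yield (masked, clean) batches with fresh masks.

    A new random square is drawn every time an image is sampled, instead of
    fixing one mask per image for the whole training run.
    """
    rng = np.random.default_rng(seed)
    n, h, w, _ = images.shape
    while True:
        idx = rng.choice(n, size=batch_size, replace=False)
        clean = images[idx]  # fancy indexing returns a copy
        masked = clean.copy()
        ys = rng.integers(0, h - mask_size + 1, size=batch_size)
        xs = rng.integers(0, w - mask_size + 1, size=batch_size)
        for i, (y, x) in enumerate(zip(ys, xs)):
            masked[i, y:y + mask_size, x:x + mask_size, :] = 0
        yield masked, clean

demo = np.random.default_rng(0).random((256, 32, 32, 3)).astype("float32")
batches = masked_batches(demo, batch_size=16, seed=0)
masked, clean = next(batches)
print(masked.shape, clean.shape)  # (16, 32, 32, 3) (16, 32, 32, 3)
```

With the tutorial's arrays, something like autoencoder.fit(masked_batches(x_train), steps_per_epoch=len(x_train) // 128, epochs=20, validation_data=(x_test_masked, x_test)) should work, since Keras fit accepts a Python generator that yields (inputs, targets) tuples; check the documentation for your Keras version.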

I find that Mean Squared Error (MSE) loss works well here because it penalizes the squared pixel-by-pixel difference between the reconstruction and the clean target image, averaged over every pixel and channel.
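Because MSE averages (y_true - y_pred)² over the whole image, a small hole contributes only a small share of the loss: an 8×8 square is 64 of 1,024 pixels, so most of the signal comes from visible pixels the model can copy. Masked autoencoder methods such as MAE compute the loss on the masked patches only. The NumPy sketch below compares the two on a toy prediction that is perfect everywhere except inside the hole; masked_mse is a hypothetical helper, and using it in Keras would need a custom loss or training step that also receives the mask (not shown).

```python
import numpy as np

def mse(y_true, y_pred):
    # Mean of squared differences over every pixel and channel
    return float(np.mean((y_true - y_pred) ** 2))

def masked_mse(y_true, y_pred, mask):
    # Same formula restricted to hidden pixels; mask has shape (N, H, W)
    diff = (y_true - y_pred)[mask]  # -> (num_masked_pixels, channels)
    return float(np.mean(diff ** 2))

y_true = np.ones((1, 32, 32, 3), dtype="float32")
mask = np.zeros((1, 32, 32), dtype=bool)
mask[0, :8, :8] = True  # one hidden 8x8 square
y_pred = y_true.copy()
y_pred[mask] = 0.5  # error of 0.5 inside the hole, perfect elsewhere

print(mse(y_true, y_pred))               # 0.015625
print(masked_mse(y_true, y_pred, mask))  # 0.25
```

The same reconstruction error looks 16 times smaller under whole-image MSE (0.25 × 64 / 1024), which is worth remembering when reading the loss curves from fit.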

Visualize the Results with Matplotlib

It is always rewarding to see how well the model reconstructs the hidden parts of the image after training.

def visualize_mim_results(original, masked, predicted, n=5):
    plt.figure(figsize=(12, 7))
    rows = [("Original", original), ("Masked", masked), ("Reconstructed", predicted)]
    for row, (title, images) in enumerate(rows):
        for i in range(n):
            # One row per image type, one column per example
            plt.subplot(3, n, row * n + i + 1)
            plt.imshow(images[i])
            plt.title(title)
            plt.axis("off")
    plt.tight_layout()
    plt.show()

# Generating predictions on the test set
decoded_imgs = autoencoder.predict(x_test_masked)
visualize_mim_results(x_test, x_test_masked, decoded_imgs)

I’ve written a small script here to plot the masked input alongside the model’s predicted “clean” output.

Method 4: Implement Grid-Based Patch Masking

Sometimes, a single square mask isn’t enough to force the model to learn global features, so I use a grid-masking approach.

def apply_grid_mask(images, patch_size=4, mask_ratio=0.25):
    masked_images = images.copy()
    batch, h, w, c = images.shape
    
    for i in range(batch):
        for y in range(0, h, patch_size):
            for x in range(0, w, patch_size):
                # Randomly deciding whether to mask this specific patch
                if np.random.rand() < mask_ratio:
                    masked_images[i, y:y+patch_size, x:x+patch_size, :] = 0
    return masked_images

# Applying grid mask to see different reconstruction challenges
x_train_grid = apply_grid_mask(x_train)
x_test_grid = apply_grid_mask(x_test)

This method divides the image into a grid of patch_size × patch_size patches and randomly hides some of them, which resembles the random patch masking used to pre-train Vision Transformers in methods such as MAE and SimMIM.
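apply_grid_mask decides each patch with its own coin flip, so the number of hidden patches varies from image to image (25% only on average), and the triple Python loop is slow over 50,000 CIFAR-10 images. MAE instead hides a fixed fraction of patches per image (75% in the paper). A minimal vectorized sketch of fixed-ratio masking, using the hypothetical name random_patch_mask and assuming the image height and width are divisible by patch_size, picks the patches by ranking random noise:

```python
import numpy as np

def random_patch_mask(images, patch_size=4, mask_ratio=0.25, rng=None):
    """Hypothetical helper: hide exactly round(mask_ratio * num_patches) patches per image."""
    rng = np.random.default_rng() if rng is None else rng
    n, h, w, _ = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "size must divide evenly"
    gh, gw = h // patch_size, w // patch_size  # patch grid dimensions
    num_patches = gh * gw
    num_masked = int(round(mask_ratio * num_patches))
    # Rank random noise within each image; the num_masked lowest ranks are hidden
    noise = rng.random((n, num_patches))
    ranks = np.argsort(np.argsort(noise, axis=1), axis=1)
    patch_mask = (ranks < num_masked).reshape(n, gh, gw)
    # Upsample the patch grid to pixel resolution: (n, gh, gw) -> (n, h, w)
    pixel_mask = patch_mask.repeat(patch_size, axis=1).repeat(patch_size, axis=2)
    masked = np.where(pixel_mask[..., None], 0, images).astype(images.dtype)
    return masked, pixel_mask

rng = np.random.default_rng(0)
demo = rng.random((2, 32, 32, 3)).astype("float32")
masked, pixel_mask = random_patch_mask(demo, rng=rng)
print(pixel_mask.mean(axis=(1, 2)))  # [0.25 0.25]
```

With 32×32 images and 4×4 patches there are 64 patches, so mask_ratio=0.25 hides exactly 16 of them (256 pixels) in every image.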

Compare MSE Performance Across Different Masking Methods

I always recommend comparing the loss values between different masking strategies to see which one challenges your model more.

# Evaluating the model performance on both types of masked data
score_square = autoencoder.evaluate(x_test_masked, x_test, verbose=0)
score_grid = autoencoder.evaluate(x_test_grid, x_test, verbose=0)

print(f"Square Mask MSE: {score_square}")
print(f"Grid Mask MSE: {score_grid}")

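Keep in mind that this autoencoder was trained only on square masks, so the grid score also reflects how well it copes with a mask type it never saw during training; for a fair comparison, train a second autoencoder on x_train_grid and evaluate both. In my experience, grid masking usually results in a slightly higher initial loss but leads to better feature generalization in the long run.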
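Part of any gap between the two scores comes from how much of the image each strategy hides. An 8×8 square covers 64 of the 1,024 pixels in a 32×32 image, while apply_grid_mask with mask_ratio=0.25 hides a quarter of the pixels on average, about four times as much. The sketch below just does that bookkeeping; square_fraction and grid_fraction are hypothetical helpers, and grid_fraction repeats apply_grid_mask's per-patch coin flip without touching any images.

```python
import numpy as np

def square_fraction(mask_size=8, size=32):
    # One mask_size x mask_size square in a size x size image
    return mask_size ** 2 / size ** 2

def grid_fraction(num_images=1000, patch_size=4, mask_ratio=0.25, size=32, seed=0):
    # Same per-patch coin flip as apply_grid_mask, counted rather than applied;
    # with size divisible by patch_size, every patch covers the same area
    rng = np.random.default_rng(seed)
    per_side = size // patch_size
    hidden = rng.random((num_images, per_side, per_side)) < mask_ratio
    return float(hidden.mean())

print(f"Square mask hides {square_fraction():.2%} of the pixels")  # Square mask hides 6.25% of the pixels
print(f"Grid mask hides {grid_fraction():.2%} of the pixels (25% expected on average)")
```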

In this tutorial, I showed you how to use Keras to implement Masked Image Modeling using Autoencoders.

This technique is a powerful way to train models in a self-supervised manner, especially when you have a lot of data but very few labels.

I hope you found this useful and can apply it to your own computer vision projects.
