Image Segmentation with a U-Net-Like Architecture in Keras

Image segmentation is one of the most exciting tasks in computer vision. It involves classifying each pixel in an image to identify objects or regions of interest. I have worked on various image processing projects using Python, and one architecture that consistently stands out for segmentation tasks is the U-Net.

U-Net was originally designed for biomedical image segmentation, but its versatility makes it a strong fit for many other applications, including urban planning, autonomous driving, and medical imaging more broadly.

In this tutorial, I will walk you through building a U-Net-like image segmentation model using Keras and Python.

What is Image Segmentation and Why Use U-Net?

Image segmentation splits an image into meaningful parts, making it easier to analyze. For example, in medical imaging, segmenting tumors from healthy tissues helps doctors make accurate diagnoses.

U-Net stands out because it combines a contracting path to capture context and a symmetric expanding path that enables precise localization. This encoder-decoder structure with skip connections preserves spatial information, which is critical for pixel-level tasks.

From my experience, U-Net’s architecture balances complexity and performance, making it a go-to model when working with Python and Keras for segmentation.

Set Up Your Python Environment for U-Net

Before diving into the code, ensure you have the necessary Python libraries installed:

pip install tensorflow numpy matplotlib

I recommend using TensorFlow 2.x as Keras is integrated into it, making model building seamless.

Build a U-Net-Like Architecture in Keras: Step-by-Step

Here’s how I build a U-Net model in Python using Keras. I keep the architecture simple yet effective for educational purposes.

import tensorflow as tf
from tensorflow.keras import layers, models

def unet_model(input_size=(128, 128, 3)):
    inputs = layers.Input(input_size)

    # Encoder: Contracting path
    c1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)

    c2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)

    c3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c3)
    p3 = layers.MaxPooling2D((2, 2))(c3)

    c4 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(p3)
    c4 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
    p4 = layers.MaxPooling2D(pool_size=(2, 2))(c4)

    # Bottleneck
    c5 = layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(p4)
    c5 = layers.Conv2D(1024, (3, 3), activation='relu', padding='same')(c5)

    # Decoder: Expanding path
    u6 = layers.Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(u6)
    c6 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(c6)

    u7 = layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(u7)
    c7 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(c7)

    u8 = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(u8)
    c8 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c8)

    u9 = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = layers.concatenate([u9, c1])
    c9 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u9)
    c9 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c9)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)

    model = models.Model(inputs=[inputs], outputs=[outputs])
    return model

model = unet_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

This Python code defines a U-Net with four downsampling and upsampling steps. The skip connections concatenate encoder features to the decoder, helping the model learn fine details.
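Each of the four MaxPooling2D layers halves the height and width, and each Conv2DTranspose doubles them again. The upsampled decoder tensors therefore only line up with the encoder tensors they are concatenated with when the input size is divisible by 2^4 = 16. The helper below is a hypothetical sketch (not part of Keras) that traces the spatial sizes, so a bad input shape is caught before concatenate raises an error.

```python
def trace_unet_sizes(size, depth=4):
    """Return the encoder feature-map sizes along one spatial dimension.

    Mirrors the model above: MaxPooling2D((2, 2)) uses 'valid' padding by
    default and floors odd sizes, while Conv2DTranspose with strides=2 and
    padding='same' exactly doubles them.
    """
    sizes = [size]
    for _ in range(depth):
        sizes.append(sizes[-1] // 2)
    return sizes


def check_unet_input(height, width, depth=4):
    """Raise ValueError if the skip connections would not line up."""
    for dim_name, dim in (("height", height), ("width", width)):
        sizes = trace_unet_sizes(dim, depth)
        # Walk back up the decoder: each upsampled tensor must match the
        # encoder tensor it is concatenated with.
        for deeper, skip in zip(reversed(sizes[1:]), reversed(sizes[:-1])):
            if deeper * 2 != skip:
                raise ValueError(
                    f"{dim_name}={dim} is not divisible by {2 ** depth}: "
                    f"upsampled size {deeper * 2} != skip size {skip}"
                )
    return trace_unet_sizes(height, depth), trace_unet_sizes(width, depth)


print(check_unet_input(128, 128))  # ([128, 64, 32, 16, 8], [128, 64, 32, 16, 8])
```

A 127×128 input, for example, fails at the deepest skip connection, where the upsampled size is 14 but the encoder tensor is 15 pixels high.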

Prepare Your Dataset for Image Segmentation in Python

For this tutorial, you can use any labeled segmentation dataset. For example, the Oxford-IIIT Pet Dataset is popular and includes pixel-level annotations.

I typically load images and masks, resize them to 128×128 pixels for faster training, and normalize pixel values.

import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array

def load_data(image_paths, mask_paths, img_size=(128, 128)):
    images = []
    masks = []
    for img_path, mask_path in zip(image_paths, mask_paths):
        img = load_img(img_path, target_size=img_size)
        img = img_to_array(img) / 255.0
        mask = load_img(mask_path, target_size=img_size, color_mode="grayscale")
        mask = img_to_array(mask) / 255.0
        images.append(img)
        masks.append(mask)
    return np.array(images), np.array(masks)
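load_data assumes that image_paths[i] and mask_paths[i] describe the same picture. One way to guarantee that is to pair files by filename stem. The sketch below assumes a layout like the Oxford-IIIT Pet download (images/*.jpg with matching annotations/trimaps/*.png); the function name and directory arguments are hypothetical, so adjust them to your dataset.

```python
from pathlib import Path


def pair_image_mask_paths(image_dir, mask_dir, image_ext=".jpg", mask_ext=".png"):
    """Return sorted, aligned lists of image and mask paths.

    Images with no mask of the same filename stem are skipped, so both
    lists always have equal length and matching order.
    """
    image_dir, mask_dir = Path(image_dir), Path(mask_dir)
    image_paths, mask_paths = [], []
    for img_path in sorted(image_dir.glob(f"*{image_ext}")):
        mask_path = mask_dir / (img_path.stem + mask_ext)
        if mask_path.exists():
            image_paths.append(str(img_path))
            mask_paths.append(str(mask_path))
    return image_paths, mask_paths


# Hypothetical usage with the Oxford-IIIT Pet folder layout:
# image_paths, mask_paths = pair_image_mask_paths("images", "annotations/trimaps")
# X, Y = load_data(image_paths, mask_paths)
```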

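The division by 255.0 in load_data assumes binary masks saved as 0/255 images. Make sure your masks match your problem: 0/1 values for the single-channel sigmoid output used here, or integer class labels for a multi-class model. Masks stored as small integer labels need remapping rather than division by 255.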
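The Oxford-IIIT Pet annotations, for example, are trimaps: 1 marks the pet, 2 the background, and 3 a border ("not classified") region. A minimal conversion sketch follows. It assumes the trimap was loaded as raw labels (load_img with color_mode="grayscale" followed by img_to_array, without dividing by 255.0), and it treats the border as part of the pet, which is one common choice rather than the only one. The helper names are hypothetical.

```python
import numpy as np

# Oxford-IIIT Pet trimap values: 1 = pet, 2 = background, 3 = border / not classified.
PET, BACKGROUND, BORDER = 1, 2, 3


def trimap_to_binary(trimap, border_as_pet=True):
    """Map a trimap to a float32 0/1 mask for the sigmoid U-Net above."""
    trimap = np.asarray(trimap)
    foreground = trimap == PET
    if border_as_pet:
        foreground |= trimap == BORDER
    return foreground.astype(np.float32)


def trimap_to_class_ids(trimap):
    """Shift labels 1..3 to 0..2 for a 3-class softmax model."""
    return (np.asarray(trimap) - 1).astype(np.int32)
```

For the multi-class variant, the last layer of unet_model would become layers.Conv2D(3, (1, 1), activation='softmax') and the model would be compiled with loss='sparse_categorical_crossentropy'.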

Train the U-Net Model in Python

Once your data is ready, training the model is straightforward.

# Assuming X_train and Y_train are your images and masks
history = model.fit(X_train, Y_train, batch_size=16, epochs=20, validation_split=0.1)
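Keras takes the validation_split fraction from the last samples of X_train and Y_train, before any shuffling. If your files are sorted (the Pet images, for instance, are named by breed), the validation set can end up covering only a few classes. A simple precaution, sketched below with a hypothetical helper, is to shuffle images and masks with one shared permutation before calling fit.

```python
import numpy as np


def shuffle_together(images, masks, seed=42):
    """Apply one random permutation to images and masks so pairs stay aligned."""
    if len(images) != len(masks):
        raise ValueError("images and masks must have the same length")
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))
    return images[order], masks[order]


# X_train, Y_train = shuffle_together(X_train, Y_train)
# history = model.fit(X_train, Y_train, batch_size=16, epochs=20, validation_split=0.1)
```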

I always monitor validation loss and accuracy to catch overfitting. You can also add callbacks such as ModelCheckpoint, which saves the best model seen so far, or EarlyStopping, which stops training once validation loss stops improving.
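A minimal sketch of those two callbacks follows. The checkpoint filename and the patience value are assumptions; restore_best_weights=True makes EarlyStopping roll the model back to the epoch with the lowest validation loss when training ends.

```python
import tensorflow as tf


def make_callbacks(checkpoint_path="unet_best.keras", patience=5):
    """Build callbacks that keep the best model and stop when val_loss stalls."""
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,            # assumed filename; .keras is the native format in recent Keras
        monitor="val_loss",
        save_best_only=True,        # only overwrite the file when val_loss improves
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=patience,          # epochs to wait without improvement
        restore_best_weights=True,  # roll back to the best epoch at the end
    )
    return [checkpoint, early_stop]


# history = model.fit(X_train, Y_train, batch_size=16, epochs=20,
#                     validation_split=0.1, callbacks=make_callbacks())
```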

Evaluate and Visualize Results

After training, visualize some predictions to see how well your model segments images.

import numpy as np
import matplotlib.pyplot as plt

def display_prediction(model, images, masks, idx=0):
    pred_mask = model.predict(images[idx:idx+1])[0]
    pred_mask = (pred_mask > 0.5).astype(np.uint8)

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.title('Input Image')
    plt.imshow(images[idx])
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title('True Mask')
    plt.imshow(masks[idx].squeeze(), cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title('Predicted Mask')
    plt.imshow(pred_mask.squeeze(), cmap='gray')
    plt.axis('off')

    plt.show()

# Example usage
display_prediction(model, X_train, Y_train, idx=5)

You can see the output in the screenshot below.

[Screenshot: Input Image, True Mask, and Predicted Mask panels produced by display_prediction]

This approach helps you visually inspect how well the Python U-Net model segments different objects.
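Visual checks are useful, but pixel accuracy can look high even when the model misses a small object, because background pixels dominate. Intersection over Union (IoU) and the Dice coefficient measure overlap with the foreground directly. The NumPy sketch below assumes binary masks like those returned by load_data and the same 0.5 threshold used in display_prediction; the function name and the X_val/Y_val arrays are hypothetical.

```python
import numpy as np


def iou_and_dice(y_true, y_prob, threshold=0.5):
    """Return (IoU, Dice) between a binary ground-truth mask and predicted probabilities."""
    true = np.asarray(y_true) > 0.5         # ground truth stored as 0/1 floats
    pred = np.asarray(y_prob) > threshold   # binarize the sigmoid outputs
    intersection = np.logical_and(true, pred).sum()
    union = np.logical_or(true, pred).sum()
    total = true.sum() + pred.sum()
    if union == 0:                          # both masks empty: perfect agreement
        return 1.0, 1.0
    return float(intersection / union), float(2 * intersection / total)


# Example on a held-out set:
# probs = model.predict(X_val)
# scores = [iou_and_dice(t, p) for t, p in zip(Y_val, probs)]
# print("mean IoU:", np.mean([s[0] for s in scores]))
```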

Alternative Methods and Tips for Image Segmentation in Python

While U-Net is powerful, you can also explore other architectures like SegNet or DeepLab, depending on your needs.

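If you need a lighter model for edge devices, consider replacing the U-Net encoder with a pretrained lightweight backbone such as MobileNet. Heavier pretrained backbones such as ResNet are not an efficiency choice, but they can improve accuracy, especially when labeled data is limited.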
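Below is a sketch of the MobileNet option, similar to the approach in TensorFlow's image segmentation tutorial: a MobileNetV2 encoder whose intermediate activations serve as skip connections, followed by a light Conv2DTranspose decoder. The layer names are those of the Keras MobileNetV2 implementation; the function name, decoder filter counts, and the weights=None default (which avoids a download; pass "imagenet" for pretrained weights) are assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers, models


def mobilenet_unet(input_size=(128, 128, 3), weights=None, train_encoder=False):
    """U-Net-like model with a MobileNetV2 encoder (sketch)."""
    base = tf.keras.applications.MobileNetV2(
        input_shape=input_size, include_top=False, weights=weights
    )
    # Activations used as skip connections, from high to low resolution.
    # For a 128x128 input they are 64, 32, 16, 8 and 4 pixels wide.
    skip_names = [
        "block_1_expand_relu",
        "block_3_expand_relu",
        "block_6_expand_relu",
        "block_13_expand_relu",
        "block_16_project",
    ]
    encoder = models.Model(
        inputs=base.input,
        outputs=[base.get_layer(name).output for name in skip_names],
    )
    encoder.trainable = train_encoder  # keep pretrained features frozen by default

    inputs = layers.Input(input_size)
    # load_data() scales pixels to [0, 1]; MobileNetV2 expects [-1, 1].
    x = layers.Rescaling(2.0, offset=-1.0)(inputs)
    *skips, x = encoder(x)  # x is the 4x4 bottleneck

    # Decoder: upsample, then concatenate with the matching encoder activation.
    for filters, skip in zip([512, 256, 128, 64], reversed(skips)):
        x = layers.Conv2DTranspose(filters, 3, strides=2, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.concatenate([x, skip])

    # Final upsampling from 64x64 back to the input resolution.
    outputs = layers.Conv2DTranspose(1, 3, strides=2, padding="same",
                                     activation="sigmoid")(x)
    return models.Model(inputs, outputs)
```

The resulting model can be compiled and trained exactly like unet_model(); freezing the encoder keeps the number of trainable parameters small, and setting train_encoder=True later allows fine-tuning.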

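Additionally, augment your dataset with flips, rotations, and zooms to improve generalization. Keras' ImageDataGenerator can do this, although recent TensorFlow releases mark it as deprecated in favor of preprocessing layers. For segmentation, whichever tool you use must apply the same random transform to each image and its mask; with ImageDataGenerator, that means two generators (one for images, one for masks) sharing the same seed.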
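A minimal NumPy sketch of paired augmentation follows: each random flip or 90-degree rotation is applied to the image and its mask in one step, so they cannot drift apart. The function names are hypothetical, and the sketch assumes square images (128×128 here), since a 90-degree rotation swaps height and width. Arbitrary-angle rotations and zooms need interpolation, where masks should use nearest-neighbour sampling so labels stay binary.

```python
import numpy as np


def augment_pair(image, mask, rng):
    """Randomly flip and rotate an (H, W, C) image and its (H, W, 1) mask together."""
    if rng.random() < 0.5:                      # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:                      # vertical flip
        image, mask = image[::-1], mask[::-1]
    k = int(rng.integers(4))                    # rotate by k * 90 degrees
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    return image, mask


def augment_batch(images, masks, seed=0):
    """Augment every image/mask pair in a batch with a reproducible RNG."""
    rng = np.random.default_rng(seed)
    pairs = [augment_pair(img, msk, rng) for img, msk in zip(images, masks)]
    return np.stack([p[0] for p in pairs]), np.stack([p[1] for p in pairs])


# X_aug, Y_aug = augment_batch(X_train, Y_train)
# X_train = np.concatenate([X_train, X_aug]); Y_train = np.concatenate([Y_train, Y_aug])
```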

Building an image segmentation model with a U-Net-like architecture in Keras is a rewarding experience for any Python developer. The model's skip connections balance localization and context, making it well suited to many real-world applications, from medical imaging to autonomous vehicles.

By following this tutorial, you now have a solid foundation to build, train, and evaluate your own segmentation models using Python. Keep experimenting with different datasets and architectures to deepen your understanding.
