Multiclass Semantic Segmentation Using DeepLabV3+ in Keras

Semantic segmentation is a vital task in computer vision, where the goal is to classify each pixel in an image into a category. Over the years, DeepLab models have become a popular choice for this due to their accuracy and efficiency. In this tutorial, I’ll share my firsthand experience working with DeepLabV3+ in Keras to perform multiclass semantic segmentation.

I will walk you through setting up the model, preparing the data, and training the network with complete code examples. By the end, you’ll have a solid foundation to build your own segmentation projects with Keras.

What is DeepLabV3+?

DeepLabV3+ is a semantic segmentation model that extends DeepLabV3 with a simple decoder module. Its encoder uses atrous (dilated) convolutions, including an Atrous Spatial Pyramid Pooling (ASPP) module, to capture multi-scale context without further downsampling. The decoder then fuses this context with low-level backbone features to refine object boundaries and produce more precise segmentation maps.
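A minimal 1D NumPy sketch shows what the dilation rate does. It illustrates the idea only; it is not how Keras implements Conv2D, and the function names are my own. A k-tap kernel with rate r reads inputs r positions apart, so it covers k + (k - 1)(r - 1) positions while keeping only k weights. "Same" zero padding keeps the output as long as the input.

```python
import numpy as np

def effective_kernel_size(k, rate):
    """Number of input positions spanned by a k-tap kernel with the given dilation rate."""
    return k + (k - 1) * (rate - 1)

def atrous_conv1d(x, w, rate):
    """'Same'-padded 1D atrous convolution (cross-correlation, as Keras Conv layers compute).
    Assumes an odd kernel length so the zero padding is symmetric."""
    k = len(w)
    pad = (effective_kernel_size(k, rate) - 1) // 2
    xp = np.pad(x, (pad, pad))  # zero padding keeps len(output) == len(x)
    # Each output reads k inputs spaced `rate` apart.
    return np.array([sum(w[j] * xp[i + j * rate] for j in range(k)) for i in range(len(x))])

x = np.arange(10)
w = [1, 0, -1]  # a simple difference filter
print(atrous_conv1d(x, w, rate=1))  # compares neighbors 1 step away
print(atrous_conv1d(x, w, rate=3))  # same 3 weights, neighbors 3 steps away
print([effective_kernel_size(3, r) for r in (6, 12, 18)])  # [13, 25, 37]
```

With k = 3, the rates 6, 12, and 18 used by the ASPP module later in this tutorial give effective spans of 13, 25, and 37 positions per axis, at no extra parameter cost.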

In Keras, implementing DeepLabV3+ means either assembling it from standard layers or using an existing implementation. I will take the first route, with an approach you can easily adapt.

Prepare the Dataset for Multiclass Segmentation in Keras

Before diving into the model, prepare the data carefully. For multiclass segmentation, each pixel in a mask stores an integer class index from 0 to NUM_CLASSES - 1.
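Annotation tools often export masks as color images rather than class-index maps. If yours do, convert them once before training. The sketch below assumes a hypothetical palette: the RGB values are placeholders, so replace them with your dataset's own. The indices follow the order in the NUM_CLASSES comment of the next snippet (road, car, pedestrian, building, background).

```python
import numpy as np

# Hypothetical palette: RGB color -> class index. Replace with your dataset's colors.
PALETTE = {
    (128, 64, 128): 0,   # road
    (0, 0, 142): 1,      # car
    (220, 20, 60): 2,    # pedestrian
    (70, 70, 70): 3,     # building
    (0, 0, 0): 4,        # background
}

def rgb_to_class_mask(rgb_mask, palette=PALETTE):
    """Convert an (H, W, 3) color mask into an (H, W) uint8 array of class indices."""
    class_mask = np.full(rgb_mask.shape[:2], -1, dtype=np.int32)
    for color, class_index in palette.items():
        # A pixel matches only if all three channels equal the palette color.
        matches = np.all(rgb_mask == np.array(color, dtype=rgb_mask.dtype), axis=-1)
        class_mask[matches] = class_index
    if (class_mask < 0).any():
        raise ValueError("mask contains colors that are not in the palette")
    return class_mask.astype(np.uint8)
```

You could then save each result as a single-channel PNG, for example with Pillow's Image.fromarray(class_mask).save(path), so that tf.image.decode_png(..., channels=1) reads the indices back unchanged. Raising on unknown colors is deliberate: an out-of-range label would otherwise reach the loss function silently.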

Here’s a simple way to load and preprocess images and their corresponding masks:

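import tensorflow as tf

IMG_SIZE = 256
NUM_CLASSES = 5  # Example: road, car, pedestrian, building, background

def load_image_mask(image_path, mask_path):
    # Read the RGB image, resize it, and scale pixel values to [0, 1].
    # (Use tf.image.decode_jpeg instead if your images are JPEG files.)
    # Note: Keras' ImageNet ResNet50 weights were trained with
    # tf.keras.applications.resnet50.preprocess_input on 0-255 values;
    # switching to it may suit the pretrained backbone better.
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0

    # Read the single-channel mask; every pixel must hold a class index in
    # [0, NUM_CLASSES - 1]. Nearest-neighbor resizing keeps the indices intact,
    # whereas bilinear resizing would blend neighboring labels into invalid values.
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE), method='nearest')
    mask = tf.cast(mask, tf.int32)

    return image, mask

def data_generator(image_paths, mask_paths, batch_size=16):
    # Build a tf.data pipeline from two parallel lists of file paths.
    # For training, consider adding dataset.shuffle(...) before batching.
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    dataset = dataset.map(load_image_mask, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset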

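This pipeline resizes images and scales them to [0, 1]. It resizes masks with nearest-neighbor interpolation so that class indices are never blended. Finally, it batches and prefetches the pairs for training.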
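The training step later in this tutorial expects parallel lists of image and mask paths (train_image_paths, train_mask_paths, val_image_paths, val_mask_paths). Below is one way to build them. It assumes a hypothetical folder layout in which images/ and masks/ contain files with identical names. The helpers list_pairs and train_val_split are my own, not part of Keras.

```python
import os
import random

def list_pairs(image_dir, mask_dir, ext=".png"):
    """Pair every image in image_dir with the mask of the same name in mask_dir."""
    names = sorted(f for f in os.listdir(image_dir) if f.endswith(ext))
    missing = [n for n in names if not os.path.isfile(os.path.join(mask_dir, n))]
    if missing:
        raise FileNotFoundError(f"No matching mask for: {missing[:5]}")
    image_paths = [os.path.join(image_dir, n) for n in names]
    mask_paths = [os.path.join(mask_dir, n) for n in names]
    return image_paths, mask_paths

def train_val_split(image_paths, mask_paths, val_fraction=0.2, seed=42):
    """Shuffle image/mask pairs together (keeping them aligned) and split off a validation set."""
    pairs = list(zip(image_paths, mask_paths))
    random.Random(seed).shuffle(pairs)
    n_val = int(len(pairs) * val_fraction)

    def unzip(ps):
        return [img for img, _ in ps], [msk for _, msk in ps]

    train_images, train_masks = unzip(pairs[n_val:])
    val_images, val_masks = unzip(pairs[:n_val])
    return train_images, train_masks, val_images, val_masks

# Example usage with the hypothetical layout data/images and data/masks:
# image_paths, mask_paths = list_pairs("data/images", "data/masks")
# (train_image_paths, train_mask_paths,
#  val_image_paths, val_mask_paths) = train_val_split(image_paths, mask_paths)
```

Shuffling the pairs rather than the two lists separately is what keeps each image matched to its own mask.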

Build DeepLabV3+ Model in Keras

I prefer building DeepLabV3+ from scratch with the Keras functional API for better control and customization.

Here’s a concise implementation of DeepLabV3+ with a ResNet50 backbone:

from tensorflow.keras import layers, Model
from tensorflow.keras.applications import ResNet50

def ASPP(inputs, out_channels=256):
    # Atrous Spatial Pyramid Pooling: five parallel branches look at the same
    # feature map with different receptive fields; their outputs are then fused.
    shape = inputs.shape  # (batch, height, width, channels); height/width must be static

    # Branch 1: 1x1 convolution
    y1 = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(inputs)
    y1 = layers.BatchNormalization()(y1)
    y1 = layers.Activation("relu")(y1)

    # Branches 2-4: 3x3 atrous convolutions with dilation rates 6, 12 and 18
    y2 = layers.Conv2D(out_channels, 3, dilation_rate=6, padding="same", use_bias=False)(inputs)
    y2 = layers.BatchNormalization()(y2)
    y2 = layers.Activation("relu")(y2)

    y3 = layers.Conv2D(out_channels, 3, dilation_rate=12, padding="same", use_bias=False)(inputs)
    y3 = layers.BatchNormalization()(y3)
    y3 = layers.Activation("relu")(y3)

    y4 = layers.Conv2D(out_channels, 3, dilation_rate=18, padding="same", use_bias=False)(inputs)
    y4 = layers.BatchNormalization()(y4)
    y4 = layers.Activation("relu")(y4)

    # Branch 5: image-level features (global average pooling -> 1x1 conv -> upsample back)
    y5 = layers.GlobalAveragePooling2D()(inputs)
    y5 = layers.Reshape((1, 1, shape[-1]))(y5)
    y5 = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(y5)
    y5 = layers.BatchNormalization()(y5)
    y5 = layers.Activation("relu")(y5)
    y5 = layers.UpSampling2D(size=(shape[1], shape[2]), interpolation='bilinear')(y5)

    # Concatenate all branches and project back to out_channels
    y = layers.Concatenate()([y1, y2, y3, y4, y5])
    y = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    return y

def DeepLabV3Plus(input_shape=(256, 256, 3), num_classes=NUM_CLASSES):
    # ImageNet-pretrained ResNet50 encoder without its classification head
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

    layer_names = [
        'conv2_block3_2_relu',  # low-level features, output stride 4 (64x64 for a 256x256 input)
        'conv4_block6_2_relu'   # high-level features, output stride 16 (16x16)
    ]
    low_level_feature = base_model.get_layer(layer_names[0]).output
    high_level_feature = base_model.get_layer(layer_names[1]).output

    # Encoder: multi-scale context from ASPP on the high-level features
    x = ASPP(high_level_feature)

    # Decoder: upsample 4x so the ASPP output matches the low-level features (16x16 -> 64x64)
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)

    # Project low-level features to 48 channels so they don't outweigh the ASPP features
    low_level_feature = layers.Conv2D(48, 1, padding='same', use_bias=False)(low_level_feature)
    low_level_feature = layers.BatchNormalization()(low_level_feature)
    low_level_feature = layers.Activation('relu')(low_level_feature)

    # Fuse both paths and refine them with two 3x3 convolutions
    x = layers.Concatenate()([x, low_level_feature])
    x = layers.Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # Final 4x upsampling back to the input resolution (64x64 -> 256x256)
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)

    # 1x1 convolution + softmax: per-pixel probabilities over num_classes
    outputs = layers.Conv2D(num_classes, 1, padding='same', activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=outputs)
    return model

model = DeepLabV3Plus()
model.summary()

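The ASPP module applies a 1×1 convolution, three parallel 3×3 atrous convolutions (rates 6, 12, and 18), and image-level pooling to the high-level backbone features. The decoder upsamples the result by 4 and concatenates it with the projected low-level features. It then refines the combination with two 3×3 convolutions and upsamples by 4 again to produce per-pixel class probabilities.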
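The two 4× upsampling steps only line up if the ASPP input has output stride 16 and the low-level features have output stride 4. In Keras' ResNet50, the conv2_*, conv3_*, conv4_*, and conv5_* stages produce features at output strides 4, 8, 16, and 32. A quick pure-Python shape trace can therefore catch a mismatched layer choice before Keras raises an error in Concatenate. decoder_shapes is a hypothetical helper that assumes square inputs.

```python
def decoder_shapes(input_size, low_stride, high_stride, aspp_upsample=4, final_upsample=4):
    """Trace spatial sizes through the DeepLabV3+ decoder for a square input."""
    low = input_size // low_stride        # low-level feature map size
    high = input_size // high_stride      # ASPP input (and output) size
    upsampled = high * aspp_upsample      # after the first UpSampling2D
    if upsampled != low:
        raise ValueError(f"Concatenate mismatch: {upsampled}x{upsampled} vs {low}x{low}")
    output = upsampled * final_upsample   # after the final UpSampling2D
    if output != input_size:
        raise ValueError(f"Output is {output}x{output}, expected {input_size}x{input_size}")
    return {"low_level": low, "aspp": high, "decoder": upsampled, "output": output}

print(decoder_shapes(256, low_stride=4, high_stride=16))
# {'low_level': 64, 'aspp': 16, 'decoder': 64, 'output': 256}
```

Pairing stride-16 and stride-32 features instead (conv4 and conv5) would raise the Concatenate error here, because the 8×8 ASPP output upsampled by 4 is 32×32, not 16×16.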

Compile and Train the DeepLabV3+ Model in Keras

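Once the model is ready, pick a loss that matches the label format. The masks store integer class indices rather than one-hot vectors, so sparse_categorical_crossentropy is the right choice. Adam with its default settings is a reasonable starting optimizer.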

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
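Sparse and one-hot (categorical) cross-entropy compute the same per-pixel quantity: the negative log of the probability assigned to the true class, averaged over pixels. The sparse version simply reads the class index directly, so the masks never need one-hot encoding. Below is a NumPy sketch of that computation. It is an illustration, not Keras' implementation, and the function names are my own.

```python
import numpy as np

def sparse_ce(labels, probs, eps=1e-7):
    """Mean per-pixel cross-entropy from integer labels.
    labels: (H, W) ints in [0, C-1]; probs: (H, W, C) softmax outputs."""
    probs = np.clip(probs, eps, 1.0)
    # Pick the probability of the true class at every pixel.
    p_true = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    return float(-np.log(p_true).mean())

def categorical_ce(one_hot, probs, eps=1e-7):
    """The same loss computed from one-hot labels of shape (H, W, C)."""
    probs = np.clip(probs, eps, 1.0)
    return float(-(one_hot * np.log(probs)).sum(axis=-1).mean())

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4, 5))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # per-pixel softmax
labels = rng.integers(0, 5, size=(4, 4))
print(np.isclose(sparse_ce(labels, probs), categorical_ce(np.eye(5)[labels], probs)))
```

Both functions return the same value. The sparse form avoids materializing an (H, W, C) one-hot tensor for every mask.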

For training, I build the training and validation pipelines with data_generator:

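# train_image_paths, train_mask_paths, val_image_paths and val_mask_paths are
# lists of file paths to your images and their matching masks, in the same order.
train_dataset = data_generator(train_image_paths, train_mask_paths)
val_dataset = data_generator(val_image_paths, val_mask_paths)

history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    epochs=30)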

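Because the model outputs a per-pixel softmax over NUM_CLASSES classes, the cross-entropy loss is computed at every pixel and averaged, which trains the network to classify each pixel. Comparing the training and validation curves in history.history is a quick way to spot overfitting.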
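Optionally, the learning rate can decay during training. The DeepLab papers use a "poly" policy, lr = base_lr * (1 - iter / max_iter) ** power with power = 0.9. The sketch below is a per-epoch approximation. base_lr = 1e-3 (Adam's default) and total_epochs = 30 (matching the fit call above) are assumptions, and make_poly_schedule is a hypothetical helper.

```python
def make_poly_schedule(base_lr=1e-3, total_epochs=30, power=0.9):
    """Return a schedule(epoch, lr) function implementing per-epoch 'poly' decay."""
    def schedule(epoch, lr=None):
        # The current rate passed in by Keras is ignored; the new rate depends only on the epoch.
        return base_lr * (1.0 - epoch / total_epochs) ** power
    return schedule

schedule = make_poly_schedule()
print([schedule(e) for e in (0, 15, 29)])  # decays smoothly toward zero
```

Keras can apply it through the LearningRateScheduler callback, which calls the function with the 0-based epoch index and the current rate: model.fit(..., callbacks=[tf.keras.callbacks.LearningRateScheduler(make_poly_schedule())]).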

Evaluate Model Performance

To evaluate the model, I prefer mean Intersection over Union (mIoU), which scores segmentation quality per class rather than per pixel.
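Formally, let P_c and G_c be the sets of pixels predicted as and labeled as class c. Then TP_c, FP_c, and FN_c count the true positives, false positives, and false negatives for that class. The per-class IoU and its mean over the C classes are:

```latex
\mathrm{IoU}_c = \frac{|P_c \cap G_c|}{|P_c \cup G_c|} = \frac{TP_c}{TP_c + FP_c + FN_c},
\qquad
\mathrm{mIoU} = \frac{1}{C} \sum_{c=1}^{C} \mathrm{IoU}_c
```

A class that appears in neither P_c nor G_c has an undefined IoU (0/0), so implementations commonly leave such classes out of the average.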

Here’s a simple mIoU calculation:

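import tensorflow as tf

def mean_iou(y_true, y_pred):
    # y_true: (batch, H, W, 1) class indices; y_pred: (batch, H, W, NUM_CLASSES) probabilities
    y_pred = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), tf.int32)
    intersections, unions = [], []
    for i in range(NUM_CLASSES):
        true_i = tf.equal(y_true, i)
        pred_i = tf.equal(y_pred, i)
        intersections.append(tf.reduce_sum(tf.cast(tf.logical_and(true_i, pred_i), tf.float32)))
        unions.append(tf.reduce_sum(tf.cast(tf.logical_or(true_i, pred_i), tf.float32)))
    intersections = tf.stack(intersections)
    unions = tf.stack(unions)
    # Average only over classes present in the labels or predictions of this batch;
    # otherwise an absent class would count as IoU 0 and drag the mean down.
    present = unions > 0
    return tf.reduce_mean(tf.boolean_mask(intersections / tf.maximum(unions, 1.0), present))

# Compile with this metric before calling model.fit so it is reported every epoch.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', mean_iou])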

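Mean IoU is more informative than pixel accuracy when classes are imbalanced. A model that labels every pixel as the dominant class (often background) can still reach high pixel accuracy while scoring an IoU of zero on every other class. Keep in mind that this metric is computed per batch and then averaged over the epoch, so it only approximates the mIoU over the whole dataset.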
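For a dataset-level score, one option is to accumulate a confusion matrix over all validation batches and compute IoU from it once at the end. The NumPy sketch below does this. The function names are my own, and the commented loop assumes the model and val_dataset defined earlier.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """num_classes x num_classes pixel counts (rows = true class, columns = predicted class)."""
    idx = num_classes * y_true.ravel().astype(np.int64) + y_pred.ravel().astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def miou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN); classes that never occur get NaN and are skipped."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as class c, labeled as another class
    fn = cm.sum(axis=1) - tp   # labeled as class c, predicted as another class
    denom = tp + fp + fn
    iou = np.full(cm.shape[0], np.nan)
    valid = denom > 0
    iou[valid] = tp[valid] / denom[valid]
    return iou, float(np.nanmean(iou))

# Accumulating over the validation set (assumes model, val_dataset and NUM_CLASSES from above):
# cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
# for images, masks in val_dataset:
#     preds = np.argmax(model.predict(images, verbose=0), axis=-1)
#     cm += confusion_matrix(masks.numpy()[..., 0], preds, NUM_CLASSES)
# per_class_iou, miou = miou_from_confusion(cm)
```

The per-class IoU array is often more useful than the single mean, because it shows which classes the model struggles with.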

Use the Trained Model for Prediction

After training, you can predict a segmentation mask for a new image, applying the same preprocessing used during training:

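def predict_mask(model, image_path):
    # Apply exactly the same preprocessing as load_image_mask
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.expand_dims(image, axis=0)  # add a batch dimension: (1, IMG_SIZE, IMG_SIZE, 3)

    pred_mask = model.predict(image)           # (1, IMG_SIZE, IMG_SIZE, NUM_CLASSES) probabilities
    pred_mask = tf.argmax(pred_mask, axis=-1)  # most likely class at each pixel
    pred_mask = tf.squeeze(pred_mask)          # drop the batch dimension: (IMG_SIZE, IMG_SIZE)
    return pred_mask.numpy()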

mask = predict_mask(model, 'test_image.png')

You can refer to the screenshot below to see the output.

[Screenshot: Multiclass Semantic Segmentation Using DeepLabV3+ in Keras]

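predict_mask returns an IMG_SIZE × IMG_SIZE array of class indices (0 to NUM_CLASSES - 1), one per pixel. Because the input is resized first, resize the mask back with nearest-neighbor interpolation if you need it at the original image resolution.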
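To inspect results visually, you can map the class-index mask to colors and blend it over the input. This NumPy sketch assumes a hypothetical display color per class, in the same order as the NUM_CLASSES comment at the top (road, car, pedestrian, building, background). It also includes a nearest-neighbor resize for returning the mask to the original image size. colorize, resize_nearest, and overlay are illustrative helpers, not library functions.

```python
import numpy as np

# Hypothetical display colors, one RGB triple per class index (adapt to your classes).
COLORS = np.array([
    [128, 64, 128],   # 0: road
    [0, 0, 142],      # 1: car
    [220, 20, 60],    # 2: pedestrian
    [70, 70, 70],     # 3: building
    [0, 0, 0],        # 4: background
], dtype=np.uint8)

def colorize(mask, colors=COLORS):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return colors[mask]

def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbor resize of a label map, so class indices are never blended."""
    rows = (np.arange(out_h) * mask.shape[0] / out_h).astype(int)
    cols = (np.arange(out_w) * mask.shape[1] / out_w).astype(int)
    return mask[rows[:, None], cols[None, :]]

def overlay(image, mask, alpha=0.5, colors=COLORS):
    """Blend the colorized mask over a uint8 RGB image with the same height and width."""
    blended = (1 - alpha) * image.astype(np.float32) + alpha * colorize(mask, colors)
    return blended.astype(np.uint8)
```

For example, if the original image of height h and width w is loaded as a uint8 array, overlay(image, resize_nearest(mask, h, w)) produces a blended preview that you can display with matplotlib's plt.imshow.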

I hope you found this tutorial on multiclass semantic segmentation using DeepLabV3+ in Keras straightforward and practical. The combination of a powerful model architecture and clear data handling makes it easier to apply segmentation in real-world projects.

Feel free to experiment with different backbones or augment your dataset to improve results. DeepLabV3+ is a versatile model that can be fine-tuned for various applications, from autonomous driving to medical imaging.
