Semantic segmentation is a vital task in computer vision, where the goal is to classify each pixel in an image into a category. Over the years, DeepLab models have become a popular choice for this due to their accuracy and efficiency. In this tutorial, I’ll share my firsthand experience working with DeepLabV3+ in Keras to perform multiclass semantic segmentation.
I will walk you through setting up the model, preparing the data, and training the network with complete code examples. By the end, you’ll have a solid foundation to build your own segmentation projects with Keras.
What is DeepLabV3+?
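DeepLabV3+ is a semantic segmentation model that combines atrous (dilated) convolution with an encoder-decoder architecture. Its encoder uses Atrous Spatial Pyramid Pooling (ASPP), a set of parallel atrous convolutions with different dilation rates plus an image-level pooling branch, to capture context at multiple scales. Compared with DeepLabV3, it adds a lightweight decoder that fuses the upsampled ASPP output with low-level backbone features, which refines object boundaries and produces more precise segmentation maps.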
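To make atrous convolution concrete: a 3×3 kernel with dilation rate r samples the input r pixels apart, so it covers the same area as a dense kernel of size 3 + 2(r - 1) with zeros between the taps. The NumPy sketch below builds that equivalent dense kernel purely for illustration; `dilate_kernel` is a hypothetical helper, not a Keras function.

```python
import numpy as np

def dilate_kernel(kernel: np.ndarray, rate: int) -> np.ndarray:
    """Return the dense kernel equivalent to applying `kernel` with dilation `rate`.

    Illustrative only: it shows which input positions a dilated kernel touches.
    """
    k = kernel.shape[0]
    size = k + (k - 1) * (rate - 1)        # effective (receptive) kernel size
    dense = np.zeros((size, size), dtype=kernel.dtype)
    dense[::rate, ::rate] = kernel         # the original weights, `rate` pixels apart
    return dense

kernel = np.arange(1, 10, dtype=np.float32).reshape(3, 3)
for rate in (1, 6, 12, 18):                # 1 = ordinary conv; 6/12/18 = the ASPP rates in this tutorial
    dense = dilate_kernel(kernel, rate)
    print(rate, dense.shape, np.count_nonzero(dense))
# 1 (3, 3) 9
# 6 (13, 13) 9
# 12 (25, 25) 9
# 18 (37, 37) 9
```

The number of weights stays at nine per input/output channel pair while the receptive field grows from 3 to 37 pixels, which is how ASPP gathers context at several scales with the same number of parameters as an ordinary 3×3 convolution per branch.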
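Keras does not include DeepLabV3+ among its built-in `keras.applications` models, so you either assemble it from standard layers or reuse an existing implementation. In this tutorial, I assemble it from standard Keras layers, which keeps the model easy to adapt.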
Prepare the Dataset for Multiclass Segmentation in Keras
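Before diving into the model, prepare the data carefully. For multiclass segmentation, each image needs a mask of the same size in which every pixel stores an integer class index from 0 to `NUM_CLASSES - 1`; the code below expects these masks as single-channel PNG files.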
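Some datasets store masks as color (RGB) images instead of class-index images. Since the loading code in this tutorial reads masks as single-channel images of class indices, color masks need converting first. A minimal NumPy sketch, assuming a hypothetical palette that assigns one RGB color to each of the five example classes (substitute your dataset's real colors); `rgb_mask_to_class_indices` is an illustrative helper name:

```python
import numpy as np

# Hypothetical palette: row i is the RGB color of class index i (replace with your dataset's colors)
PALETTE = np.array([
    [128, 64, 128],   # 0: road
    [0, 0, 142],      # 1: car
    [220, 20, 60],    # 2: pedestrian
    [70, 70, 70],     # 3: building
    [0, 0, 0],        # 4: background
], dtype=np.uint8)

def rgb_mask_to_class_indices(rgb_mask: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) color mask into an (H, W) uint8 array of class indices.

    Pixels whose color is not in the palette become 255 so they can be found
    and fixed before training.
    """
    # Compare every pixel with every palette color: (H, W, 1, 3) vs (1, 1, C, 3) -> (H, W, C)
    matches = np.all(rgb_mask[:, :, None, :] == palette[None, None, :, :], axis=-1)
    indices = np.argmax(matches, axis=-1).astype(np.uint8)
    indices[~matches.any(axis=-1)] = 255   # unknown colors
    return indices

rgb = np.array([[[128, 64, 128], [0, 0, 142]],
                [[70, 70, 70], [1, 2, 3]]], dtype=np.uint8)
print(rgb_mask_to_class_indices(rgb, PALETTE))
# [[  0   1]
#  [  3 255]]
```

The resulting array can be saved as a single-channel PNG (for example with Pillow's `Image.fromarray(indices)`), which is the format the loader reads with `channels=1`.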
Here’s a simple way to load and preprocess images and their corresponding masks:
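```python
import tensorflow as tf

IMG_SIZE = 256
NUM_CLASSES = 5  # Example: road, car, pedestrian, building, background

def load_image_mask(image_path, mask_path):
    # Image: decode as RGB (assumes PNG files; use tf.io.decode_jpeg for JPEGs),
    # resize to IMG_SIZE x IMG_SIZE and scale to [0, 1]
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0
    # Note: Keras's ImageNet ResNet50 weights were trained with
    # tf.keras.applications.resnet50.preprocess_input; using it here (and in
    # predict_mask) instead of dividing by 255 may transfer better.

    # Mask: one channel of class indices. Nearest-neighbor resizing keeps every
    # pixel a valid class index instead of interpolating between labels.
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE), method='nearest')
    mask = tf.cast(mask, tf.int32)
    return image, mask

def data_generator(image_paths, mask_paths, batch_size=16, shuffle=True):
    # image_paths and mask_paths: equal-length lists of file paths in matching order
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    if shuffle:
        # Shuffle the (cheap) file paths so every epoch sees a new sample order
        dataset = dataset.shuffle(len(image_paths), reshuffle_each_iteration=True)
    dataset = dataset.map(load_image_mask, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset
```

This pipeline decodes each image/mask pair, resizes both to `IMG_SIZE` × `IMG_SIZE` (nearest-neighbor for masks, so labels stay valid class indices), scales pixel values to [0, 1], and shuffles, batches and prefetches the data for training.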
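Why nearest-neighbor for masks? Class indices are labels, not intensities, so interpolating between neighboring labels invents values that mean something else entirely. A small 1-D NumPy illustration, using `np.interp` as a stand-in for bilinear resizing:

```python
import numpy as np

# A 1-D strip of mask labels: class 0 next to class 4
labels = np.array([0, 0, 4, 4], dtype=np.float32)
old_x = np.arange(labels.size)                  # original pixel positions 0..3
new_x = np.linspace(0, labels.size - 1, 7)      # upsample 4 -> 7 samples

linear = np.interp(new_x, old_x, labels)        # what linear/bilinear resizing does
nearest = labels[np.rint(new_x).astype(int)]    # copy the closest label (ties round to even)

print(linear)    # [0. 0. 0. 2. 4. 4. 4.]
print(nearest)   # [0. 0. 0. 4. 4. 4. 4.]
```

Linear interpolation produces a pixel labeled 2 at the boundary; if the classes are numbered in the order of the `NUM_CLASSES` comment, that would be "pedestrian", a class that is not there at all. Nearest-neighbor only ever copies existing labels.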
Build DeepLabV3+ Model in Keras
I prefer building DeepLabV3+ from scratch with the Keras functional API, which gives full control over the architecture and makes it easy to customize.
Here’s a concise implementation of DeepLabV3+ with a ResNet50 backbone:
```python
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import ResNet50

def ASPP(inputs, out_channels=256):
    """Atrous Spatial Pyramid Pooling: parallel branches with different receptive fields."""
    shape = inputs.shape

    # Branch 1: 1x1 convolution
    y1 = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(inputs)
    y1 = layers.BatchNormalization()(y1)
    y1 = layers.Activation("relu")(y1)

    # Branches 2-4: 3x3 atrous convolutions with dilation rates 6, 12 and 18
    y2 = layers.Conv2D(out_channels, 3, dilation_rate=6, padding="same", use_bias=False)(inputs)
    y2 = layers.BatchNormalization()(y2)
    y2 = layers.Activation("relu")(y2)

    y3 = layers.Conv2D(out_channels, 3, dilation_rate=12, padding="same", use_bias=False)(inputs)
    y3 = layers.BatchNormalization()(y3)
    y3 = layers.Activation("relu")(y3)

    y4 = layers.Conv2D(out_channels, 3, dilation_rate=18, padding="same", use_bias=False)(inputs)
    y4 = layers.BatchNormalization()(y4)
    y4 = layers.Activation("relu")(y4)

    # Branch 5: image-level features (global average pool, 1x1 conv, upsample back)
    y5 = layers.GlobalAveragePooling2D()(inputs)
    y5 = layers.Reshape((1, 1, shape[-1]))(y5)
    y5 = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(y5)
    y5 = layers.BatchNormalization()(y5)
    y5 = layers.Activation("relu")(y5)
    y5 = layers.UpSampling2D(size=(shape[1], shape[2]), interpolation='bilinear')(y5)

    # Fuse all branches and project back to out_channels
    y = layers.Concatenate()([y1, y2, y3, y4, y5])
    y = layers.Conv2D(out_channels, 1, padding="same", use_bias=False)(y)
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    return y

def DeepLabV3Plus(input_shape=(256, 256, 3), num_classes=NUM_CLASSES):
    # Assumes the input height and width are divisible by 16
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)

    # High-level features for ASPP: stride 16 (16x16 for a 256x256 input)
    high_level_feature = base_model.get_layer('conv4_block6_2_relu').output
    # Low-level features for the decoder: stride 4 (64x64 for a 256x256 input)
    low_level_feature = base_model.get_layer('conv2_block3_2_relu').output

    x = ASPP(high_level_feature)
    # Upsample 4x (16x16 -> 64x64) to match the low-level features
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)

    # 1x1 conv reduces the low-level features to 48 channels, as in the DeepLabV3+ decoder
    low_level_feature = layers.Conv2D(48, 1, padding='same', use_bias=False)(low_level_feature)
    low_level_feature = layers.BatchNormalization()(low_level_feature)
    low_level_feature = layers.Activation('relu')(low_level_feature)

    # Decoder: fuse both feature maps and refine with two 3x3 convolutions
    x = layers.Concatenate()([x, low_level_feature])
    x = layers.Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(256, 3, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # Upsample 4x (64x64 -> 256x256) back to the input resolution
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)
    # Per-pixel softmax over the classes
    outputs = layers.Conv2D(num_classes, 1, padding='same', activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=outputs)
    return model

model = DeepLabV3Plus()
model.summary()
```

This builds DeepLabV3+ on an ImageNet-pretrained ResNet50: the stride-16 `conv4` features pass through ASPP, the result is upsampled 4× and fused with the stride-4 `conv2` features, and a final 4× upsampling yields per-pixel class probabilities at the full input resolution. The decoder's `Concatenate` requires both inputs to have the same spatial size, which is why the low-level features must come from the stride-4 stage.
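It is worth checking the feature-map geometry of the standard DeepLabV3+ layout (stride-16 features into ASPP, stride-4 low-level features into the decoder) for a 256×256 input, and how the ASPP dilation rates interact with it. The plain-Python sketch below assumes ResNet50's usual strides and `padding='same'`; `valid_tap_fraction` is a hypothetical helper, not a Keras function.

```python
IMG_SIZE = 256
LOW_LEVEL_STRIDE = 4   # assumed: ResNet50 conv2_x stage
ASPP_STRIDE = 16       # assumed: ResNet50 conv4_x stage

low_level_size = IMG_SIZE // LOW_LEVEL_STRIDE    # 64
aspp_size = IMG_SIZE // ASPP_STRIDE              # 16
first_upsample = low_level_size // aspp_size     # ASPP output -> low-level resolution
final_upsample = IMG_SIZE // low_level_size      # decoder output -> input resolution
print(low_level_size, aspp_size, first_upsample, final_upsample)  # 64 16 4 4

def valid_tap_fraction(size: int, rate: int, kernel: int = 3) -> float:
    """Average fraction of a dilated kernel's taps that land inside a size x size
    feature map under padding='same' (the remaining taps read zero padding)."""
    half = kernel // 2
    offsets = [rate * (k - half) for k in range(kernel)]   # e.g. [-6, 0, 6]
    inside = sum(0 <= pos + off < size for pos in range(size) for off in offsets)
    frac_1d = inside / (size * kernel)
    return frac_1d ** 2   # rows and columns are independent, so square the 1-D fraction

for rate in (6, 12, 18):
    print(rate, valid_tap_fraction(aspp_size, rate))
```

The fractions come out to about 0.56, 0.25 and 0.11. At rate 18, only the center tap ever lands inside a 16×16 map, so that branch effectively becomes a 1×1 convolution. The DeepLabV3 paper describes this degeneration and chose rates 6/12/18 for 513×513 training crops at output stride 16, so with 256×256 inputs you may want to experiment with smaller rates.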
Compile and Train the DeepLabV3+ Model in Keras
Once the model is ready, compiling with the right loss and optimizer is key for multiclass segmentation.
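`sparse_categorical_crossentropy` takes each pixel's integer class index directly, so the masks never need one-hot encoding. Per pixel, it is the negative log of the softmax probability assigned to the true class, which equals categorical cross-entropy against the one-hot version of the label. A NumPy sketch of that equivalence on a toy 2×2 mask with 3 classes (made-up numbers, not model output):

```python
import numpy as np

# Toy softmax output for a 2x2 image with 3 classes: shape (H, W, C), each pixel sums to 1
probs = np.array([[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
                  [[0.2, 0.2, 0.6], [0.5, 0.25, 0.25]]])
mask = np.array([[0, 1],
                 [2, 1]])  # integer class index per pixel

# Sparse form: pick the probability of the true class at every pixel
p_true = np.take_along_axis(probs, mask[..., None], axis=-1)[..., 0]
sparse_ce = -np.log(p_true).mean()

# Dense form: one-hot encode the mask, then standard categorical cross-entropy
one_hot = np.eye(probs.shape[-1])[mask]            # (H, W, C)
dense_ce = -(one_hot * np.log(probs)).sum(axis=-1).mean()

print(sparse_ce, dense_ce)   # the two values match
```

Keras performs the same per-pixel computation and, by default, averages the result over all pixels in the batch.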
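```python
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```

For training, I use the dataset generator:

```python
# train_image_paths / train_mask_paths and val_image_paths / val_mask_paths are
# your own lists of file paths, with each image at the same index as its mask.
train_dataset = data_generator(train_image_paths, train_mask_paths)
val_dataset = data_generator(val_image_paths, val_mask_paths)

history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    epochs=30)
```

Because the masks hold integer class indices rather than one-hot vectors, `sparse_categorical_crossentropy` is the matching loss: it scores each pixel's softmax output against that pixel's label and averages over all pixels. The returned `history` records the loss and accuracy per epoch for both the training and validation sets.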
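Segmentation datasets are often dominated by a few classes (road and background in the example class list), and an averaged per-pixel loss then mostly rewards getting those right. Counting pixels per class shows how skewed your data is. A NumPy sketch, assuming the training masks are already loaded as integer arrays; the helper names and the inverse-frequency weighting are illustrative choices, not part of the pipeline above:

```python
import numpy as np

NUM_CLASSES = 5

def class_pixel_counts(masks, num_classes=NUM_CLASSES):
    """Total number of pixels per class over an iterable of integer masks."""
    counts = np.zeros(num_classes, dtype=np.int64)
    for m in masks:
        if m.min() < 0 or m.max() >= num_classes:
            raise ValueError("mask contains a label outside 0..num_classes-1")
        counts += np.bincount(m.ravel(), minlength=num_classes)
    return counts

def inverse_frequency_weights(counts):
    """Weights proportional to 1 / frequency, normalized to average 1 over the
    classes that occur; classes with no pixels get weight 0."""
    counts = counts.astype(np.float64)
    weights = np.zeros_like(counts)
    present = counts > 0
    weights[present] = counts[present].sum() / counts[present]
    return weights / weights[present].mean()

# Two toy 4x4 masks standing in for real training masks
masks = [np.array([[4, 4, 4, 4], [4, 4, 4, 4], [0, 0, 0, 0], [0, 0, 1, 2]]),
         np.array([[4, 4, 4, 4], [4, 4, 3, 3], [0, 0, 0, 0], [0, 0, 0, 1]])]
counts = class_pixel_counts(masks)
weights = inverse_frequency_weights(counts)
print(counts)            # [13  2  1  2 14]
print(weights.round(2))

# Per-pixel weight map for one mask: look up each pixel's class weight
pixel_weights = weights[masks[0]]
```

In many Keras versions, `model.fit`'s `class_weight` argument is not supported for per-pixel (3-D and higher) targets, so a common alternative is to have the `tf.data` pipeline yield `(image, mask, pixel_weights)` triples, which `fit` treats as sample weights. Verify this against your Keras version before relying on it.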
Evaluate Model Performance
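To evaluate the model, I prefer mean Intersection over Union (mIoU). For each class, IoU is the number of pixels where both the prediction and the ground truth have that class (the intersection), divided by the number of pixels where either of them has it (the union); mIoU is the average of these per-class scores.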
Here’s a simple mIoU calculation:
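```python
import tensorflow as tf

def mean_iou(y_true, y_pred):
    # y_pred: (batch, H, W, NUM_CLASSES) softmax scores -> predicted class per pixel
    y_pred = tf.argmax(y_pred, axis=-1)
    # y_true: (batch, H, W, 1) class indices -> (batch, H, W), same dtype as y_pred
    y_true = tf.cast(tf.squeeze(y_true, axis=-1), y_pred.dtype)
    ious, present = [], []
    for i in range(NUM_CLASSES):
        true_i = tf.equal(y_true, i)
        pred_i = tf.equal(y_pred, i)
        intersection = tf.reduce_sum(tf.cast(tf.logical_and(true_i, pred_i), tf.float32))
        union = tf.reduce_sum(tf.cast(tf.logical_or(true_i, pred_i), tf.float32))
        ious.append(tf.math.divide_no_nan(intersection, union))
        # Classes absent from both labels and prediction are left out of the mean
        present.append(tf.cast(union > 0, tf.float32))
    ious = tf.stack(ious)
    present = tf.stack(present)
    return tf.math.divide_no_nan(tf.reduce_sum(ious * present), tf.reduce_sum(present))

# Recompiling keeps the trained weights; pass mean_iou in the first compile() call
# instead if you also want it tracked during training.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', mean_iou])
results = model.evaluate(val_dataset, return_dict=True)
```

mIoU gives a more meaningful picture than pixel accuracy, which large classes such as road or background can inflate: every class contributes equally to the mean, so a poorly segmented small class (pedestrians, for example) pulls the score down. Keep in mind that Keras averages this function batch by batch, so the reported value approximates, but is not identical to, mIoU computed over the entire validation set at once.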
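For final reporting, a common alternative to batch-averaged mIoU is to accumulate a confusion matrix over the whole validation set and compute each class's IoU once at the end, as TP / (TP + FP + FN). A NumPy sketch; `confusion_matrix` and `iou_from_confusion` are illustrative helper names, and the tiny arrays are toy data standing in for `(mask, predicted mask)` batches:

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """(num_classes, num_classes) counts: rows = true class, columns = predicted class."""
    idx = y_true.ravel().astype(np.int64) * num_classes + y_pred.ravel().astype(np.int64)
    return np.bincount(idx, minlength=num_classes * num_classes).reshape(num_classes, num_classes)

def iou_from_confusion(cm):
    """Per-class IoU = TP / (TP + FP + FN); NaN for classes absent from labels and predictions."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as class c, actually another class
    fn = cm.sum(axis=1) - tp   # actually class c, predicted as another class
    denom = tp + fp + fn
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(denom > 0, tp / denom, np.nan)

num_classes = 3
batches = [
    (np.array([[0, 0], [1, 2]]), np.array([[0, 1], [1, 2]])),
    (np.array([[2, 2], [0, 1]]), np.array([[2, 0], [0, 1]])),
]
cm = np.zeros((num_classes, num_classes), dtype=np.int64)
for y_true, y_pred in batches:          # accumulate counts over all batches first
    cm += confusion_matrix(y_true, y_pred, num_classes)

ious = iou_from_confusion(cm)
print(ious, np.nanmean(ious))           # per-class IoU [0.5, 0.667, 0.667], mIoU about 0.611
```

Keras also provides `tf.keras.metrics.MeanIoU`, which accumulates a confusion matrix across batches in the same spirit. By default it expects predicted class indices, so either apply `argmax` to the softmax output yourself or check whether your version supports `sparse_y_pred=False`.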
Use the Trained Model for Prediction
After training, you can predict segmentation masks on new images easily:
```python
def predict_mask(model, image_path):
    # Same preprocessing as in load_image_mask
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.expand_dims(image, axis=0)          # add batch dimension: (1, H, W, 3)

    pred_mask = model.predict(image)               # (1, H, W, NUM_CLASSES) probabilities
    pred_mask = tf.argmax(pred_mask, axis=-1)      # most likely class per pixel: (1, H, W)
    pred_mask = tf.squeeze(pred_mask, axis=0)      # drop the batch dimension: (H, W)
    return pred_mask.numpy()

mask = predict_mask(model, 'test_image.png')
```

You can refer to the screenshot below to see the output.

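`predict_mask` returns an `IMG_SIZE` × `IMG_SIZE` NumPy array of class indices (0 to `NUM_CLASSES - 1`). Because the input is resized before prediction, resize the mask back to the original image size with nearest-neighbor interpolation if you need full-resolution labels.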
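To inspect a prediction visually, map each class index to a color; to return to the original resolution, resize with nearest-neighbor so no new class values appear. A NumPy sketch with a hypothetical five-color palette (choose your own); `colorize` and `resize_nearest` are illustrative helper names:

```python
import numpy as np

# Hypothetical palette: row i is the RGB color for class index i
PALETTE = np.array([[128, 64, 128],   # 0
                    [0, 0, 142],      # 1
                    [220, 20, 60],    # 2
                    [70, 70, 70],     # 3
                    [0, 0, 0]],       # 4
                   dtype=np.uint8)

def colorize(mask, palette=PALETTE):
    """(H, W) class indices -> (H, W, 3) uint8 RGB image via fancy indexing."""
    return palette[mask]

def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbor resize of a label mask: every output pixel copies one input
    pixel, so no new class values are created."""
    rows = np.arange(out_h) * mask.shape[0] // out_h
    cols = np.arange(out_w) * mask.shape[1] // out_w
    return mask[rows[:, None], cols[None, :]]

pred = np.array([[0, 1], [2, 4]])      # stand-in for a predict_mask(...) result
full = resize_nearest(pred, 4, 6)      # e.g. back to the original image size
rgb = colorize(full)
print(full)
print(rgb.shape)                       # (4, 6, 3)
```

The `rgb` array can be displayed with Matplotlib's `plt.imshow(rgb)` or blended with the input image for an overlay.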
I hope you found this tutorial on multiclass semantic segmentation using DeepLabV3+ in Keras straightforward and practical. The combination of a powerful model architecture and clear data handling makes it easier to apply segmentation in real-world projects.
Feel free to experiment with different backbones or augment your dataset to improve results. DeepLabV3+ is a versatile model that can be fine-tuned for various applications, from autonomous driving to medical imaging.
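When augmenting segmentation data, every geometric transform must be applied identically to the image and its mask, or the labels stop lining up with the pixels; photometric changes such as brightness apply to the image only. A minimal NumPy sketch of a paired random horizontal flip (the function name and the 50% probability are illustrative; in a `tf.data` pipeline you would express the same idea with TensorFlow ops such as `tf.image.flip_left_right` on both tensors under one shared random draw):

```python
import numpy as np

def random_flip_pair(image, mask, rng):
    """Flip image (H, W, 3) and mask (H, W, 1) left-right together with probability 0.5."""
    if rng.random() < 0.5:              # one random draw decides for both arrays
        image = image[:, ::-1, :]
        mask = mask[:, ::-1, :]
    return image, mask

image = np.arange(2 * 3 * 3).reshape(2, 3, 3)
mask = np.array([[0, 1, 2], [3, 4, 0]])[..., None]
aug_image, aug_mask = random_flip_pair(image, mask, np.random.default_rng(0))
print(aug_mask[..., 0])
```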

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and other countries.