Image segmentation is a critical task in computer vision where we classify each pixel of an image to identify objects or boundaries precisely. During my years working with Keras, I found BASNet (Boundary-Aware Salient Object Detection Network) to be one of the most effective architectures for segmenting fine boundaries in images.
In this tutorial, I will walk you through how to implement BASNet using Keras. I will provide complete, ready-to-run code snippets for each step so you can easily follow along and apply this to your own projects.
What is BASNet and Why Use Keras?
BASNet is a deep learning model for salient object detection that is designed to produce masks with accurate boundaries. It captures object edges precisely, which is useful in medical imaging, autonomous driving, and other fields where exact boundaries matter.
Keras, integrated with TensorFlow, provides a clean and efficient way to build and train BASNet models thanks to its modular API and ease of customization.
Set Up Your Environment
Before we start, make sure you have the following installed:
pip install tensorflow numpy matplotlib opencv-python
These libraries are essential for building the model, processing images, and visualizing results.
Build BASNet Architecture in Keras
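BASNet is a complex architecture: an encoder-decoder prediction network built from residual blocks, followed by a refinement module and trained with a boundary-aware hybrid loss. Below, I provide a simplified but functional version in Keras that keeps the residual encoder-decoder and leaves out the refinement stage.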
Step 1: Import Required Libraries
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, Input, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
Step 2: Define Residual Block
Residual blocks add a block's input back to its output through a skip connection, which makes deeper networks easier to train.
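The block defined next adds the shortcut directly, which only works when the input already has `filters` channels. When it does not, a 1x1 convolution on the shortcut is a common workaround. Here is a minimal sketch of that variant; the name `residual_block_projected` is hypothetical and not part of this tutorial's model.

```python
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, Activation, Add, Input
)

def residual_block_projected(x, filters, kernel_size=3):
    """Residual block that projects the shortcut when channel counts differ."""
    shortcut = x
    if x.shape[-1] != filters:
        # 1x1 convolution so the shortcut has the same channel count as the main path
        shortcut = Conv2D(filters, 1, padding='same')(shortcut)
        shortcut = BatchNormalization()(shortcut)
    y = Conv2D(filters, kernel_size, padding='same')(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(filters, kernel_size, padding='same')(y)
    y = BatchNormalization()(y)
    y = Add()([shortcut, y])
    return Activation('relu')(y)

# A 32-channel input passed through a 64-filter block
inputs = Input(shape=(64, 64, 32))
outputs = residual_block_projected(inputs, 64)
print(outputs.shape)
```

In this tutorial the block is always called right after a convolution with the same number of filters, so the plain version is sufficient.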
def residual_block(x, filters, kernel_size=3):
    # The input must already have `filters` channels so that Add() can sum it with the output
    shortcut = x
    x = Conv2D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([shortcut, x])
    x = Activation('relu')(x)
    return x
Step 3: Define Encoder Block
The encoder downsamples and extracts features.
def encoder_block(x, filters):
    # The stride-2 convolution halves the height and width
    x = Conv2D(filters, 3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = residual_block(x, filters)
    return x
Step 4: Define Decoder Block
The decoder upsamples features and refines boundaries.
def decoder_block(x, skip_connection, filters):
    # After upsampling, x must have the same height and width as skip_connection
    x = UpSampling2D(size=(2, 2))(x)
    x = Concatenate()([x, skip_connection])
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = residual_block(x, filters)
    return x
Step 5: Assemble BASNet Model
def build_basnet(input_shape=(256, 256, 3)):
    inputs = Input(shape=input_shape)
    # Encoder
    e1 = encoder_block(inputs, 64)   # 128x128
    e2 = encoder_block(e1, 128)      # 64x64
    e3 = encoder_block(e2, 256)      # 32x32
    e4 = encoder_block(e3, 512)      # 16x16
    # Bridge (keeps the 16x16 resolution)
    b = Conv2D(1024, 3, padding='same')(e4)
    b = BatchNormalization()(b)
    b = Activation('relu')(b)
    # Decoder: each stage upsamples by 2, so it is paired with the
    # encoder feature map (or the input) that has the same resolution
    d4 = decoder_block(b, e3, 512)       # 32x32
    d3 = decoder_block(d4, e2, 256)      # 64x64
    d2 = decoder_block(d3, e1, 128)      # 128x128
    d1 = decoder_block(d2, inputs, 64)   # 256x256
    # Output layer with sigmoid activation for binary mask prediction
    outputs = Conv2D(1, 1, activation='sigmoid')(d1)
    model = Model(inputs, outputs)
    return model
Training BASNet on Your Dataset
Once the model is ready, you can compile and train it on your labeled segmentation dataset.
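Before training, it can help to check that the feature-map sizes line up, because `Concatenate` only accepts tensors with the same height and width. The following is a small sketch in plain Python that traces the sizes for a 256-pixel input. The helper names `encoder_sizes` and `decoder_sizes` are hypothetical, and the sketch assumes four stride-2 encoder stages with `padding='same'` and four 2x upsampling stages.

```python
import math

def encoder_sizes(size, num_stages):
    """Spatial size after each stride-2 convolution with padding='same'."""
    sizes = []
    for _ in range(num_stages):
        size = math.ceil(size / 2)
        sizes.append(size)
    return sizes

def decoder_sizes(size, num_stages):
    """Spatial size after each 2x upsampling stage."""
    sizes = []
    for _ in range(num_stages):
        size *= 2
        sizes.append(size)
    return sizes

enc = encoder_sizes(256, 4)
dec = decoder_sizes(enc[-1], 4)  # the bridge keeps the last encoder size
print(enc)  # [128, 64, 32, 16]
print(dec)  # [32, 64, 128, 256]
```

Each decoder stage can only be concatenated with a skip tensor of the same size as its upsampled input. Input sides that are not a multiple of 16 break this: a 250-pixel side gives encoder sizes [125, 63, 32, 16], which repeated doubling cannot reproduce.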
Step 6: Compile and Train Model
model = build_basnet()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Assuming you have train_images and train_masks as numpy arrays
# train_images.shape -> (num_samples, 256, 256, 3)
# train_masks.shape -> (num_samples, 256, 256, 1)
# Example training call:
# model.fit(train_images, train_masks, batch_size=8, epochs=20, validation_split=0.1)
You can replace train_images and train_masks with your actual dataset arrays. The binary cross-entropy loss is suitable for binary segmentation tasks like boundary detection.
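If your data starts out as 8-bit images and masks, a short helper can bring the arrays into the shapes and value ranges the model expects. The following is a minimal sketch using only NumPy. The name `prepare_arrays`, the 0-255 storage format, and the threshold of 127 are assumptions for illustration, and the random arrays stand in for a real dataset.

```python
import numpy as np

def prepare_arrays(images_uint8, masks_uint8, threshold=127):
    """Scale images to [0, 1] and turn masks into 0/1 float arrays."""
    images = images_uint8.astype(np.float32) / 255.0
    masks = (masks_uint8 > threshold).astype(np.float32)
    if masks.ndim == 3:
        # Add the trailing channel axis: (N, H, W) -> (N, H, W, 1)
        masks = masks[..., np.newaxis]
    return images, masks

# Placeholder data standing in for real images and masks
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(4, 256, 256, 3), dtype=np.uint8)
raw_masks = rng.integers(0, 256, size=(4, 256, 256), dtype=np.uint8)

train_images, train_masks = prepare_arrays(raw_images, raw_masks)
print(train_images.shape, train_masks.shape)
```

Keeping the masks strictly binary matters here because binary cross-entropy treats each pixel as a 0/1 target.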
Evaluating and Visualizing Results
After training, you want to see how well BASNet segments boundaries.
Step 7: Predict and Visualize
import matplotlib.pyplot as plt
import numpy as np
def visualize_prediction(model, image):
    # Add a batch axis for predict(), then take the single predicted mask
    pred_mask = model.predict(np.expand_dims(image, axis=0))[0]
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Input Image")
    plt.imshow(image)
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.title("Predicted Boundary Mask")
    plt.imshow(pred_mask.squeeze(), cmap='gray')
    plt.axis('off')
    plt.show()
You can see the output in the screenshot below.

Use this function by passing a test image (preprocessed to 256×256 and normalized) to quickly see the segmentation mask.
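Beyond visual inspection, a numeric score makes it easier to compare runs. The following is a minimal sketch of intersection over union (IoU) between a predicted probability mask and a ground-truth mask, using only NumPy. The function name `binary_iou` and the 0.5 threshold are assumptions, and the two masks below are synthetic stand-ins for a model prediction and its label.

```python
import numpy as np

def binary_iou(pred_mask, true_mask, threshold=0.5):
    """IoU between a thresholded prediction and a binary ground-truth mask."""
    pred = np.squeeze(pred_mask) >= threshold
    true = np.squeeze(true_mask) >= 0.5
    union = np.logical_or(pred, true).sum()
    if union == 0:
        # Both masks are empty, so they agree completely
        return 1.0
    return float(np.logical_and(pred, true).sum() / union)

# Synthetic example: two 128x128 squares that overlap on 96 columns
true_mask = np.zeros((256, 256, 1), dtype=np.float32)
true_mask[64:192, 64:192] = 1.0
pred_mask = np.zeros((256, 256, 1), dtype=np.float32)
pred_mask[64:192, 96:224] = 0.9

print(binary_iou(pred_mask, true_mask))  # 0.6
```

With a trained model, `pred_mask` would be the output of `model.predict` for one image.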
Why BASNet in Keras Works Well for Boundary Segmentation
From my experience, BASNet's design focuses on preserving edge details: residual blocks and skip connections carry fine spatial detail into the decoder, and the original model adds boundary-aware training on top. Keras' flexibility lets you customize and extend this model easily.
The model efficiently balances feature extraction and spatial resolution, which is key for sharp boundary detection.
Highly accurate boundary segmentation is achievable using BASNet with Keras by following the steps above. This approach can be adapted to various domains such as medical imaging, satellite imagery, and autonomous systems, where precise object boundaries are critical.
If you want to dive deeper into improving BASNet, you can experiment with advanced loss functions like IoU loss or boundary-aware losses, and augment your dataset for better generalization.
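As a starting point for such experiments, here is a minimal sketch of a loss that adds a soft IoU term to binary cross-entropy. The function names `soft_iou_loss` and `bce_iou_loss`, the equal weighting of the two terms, and the smoothing constant are assumptions for illustration, not the loss used in the original BASNet.

```python
import tensorflow as tf

def soft_iou_loss(y_true, y_pred, eps=1e-6):
    """1 - IoU computed on probabilities, one value per sample."""
    y_true = tf.cast(y_true, y_pred.dtype)
    inter = tf.reduce_sum(y_true * y_pred, axis=[1, 2, 3])
    union = tf.reduce_sum(y_true + y_pred - y_true * y_pred, axis=[1, 2, 3])
    return 1.0 - (inter + eps) / (union + eps)

def bce_iou_loss(y_true, y_pred):
    """Per-sample binary cross-entropy plus soft IoU loss."""
    y_true = tf.cast(y_true, y_pred.dtype)
    # binary_crossentropy averages over the channel axis; average over H and W too
    bce = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(y_true, y_pred), axis=[1, 2]
    )
    return bce + soft_iou_loss(y_true, y_pred)

# Quick check on a toy batch of masks with shape (batch, H, W, 1)
y_true = tf.ones((2, 4, 4, 1))
print(bce_iou_loss(y_true, tf.ones((2, 4, 4, 1))).numpy())
print(soft_iou_loss(y_true, tf.zeros((2, 4, 4, 1))).numpy())
```

The function can be passed to `model.compile(loss=bce_iou_loss, ...)` in place of `'binary_crossentropy'`.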