Enhancing low-light images has always been a personal fascination for me. It’s like magic when you can take a dark, grainy photo from a late-night street in NYC and make it look professional.
I’ve spent over four years developing with Keras, and MIRNet is one of the most effective architectures I’ve used for this task. It processes features at several resolutions in parallel while keeping a full-resolution stream, so it gathers context from the whole image without losing fine detail.
In this tutorial, I will show you how to implement low-light image enhancement using MIRNet in Keras. We’ll cover everything from building the blocks to running inference on your own dark photos.
Build Selective Kernel Feature Fusion in Keras
Selective Kernel Feature Fusion (SKFF) is the “brain” of MIRNet: it merges the streams coming from different resolutions by learning, for every channel, how much each stream should contribute. I use it to make sure the model focuses on the most relevant features of the image.
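To make the selection step concrete, here is a minimal NumPy sketch of the arithmetic SKFF performs on a single image, assuming the three streams already share the same shape. The function name `skff_numpy` and the random logits are hypothetical stand-ins for the learned layers. The key point is that the softmax runs across the streams, not across channels, so for every channel the stream weights sum to 1.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def skff_numpy(streams, logits):
    """Fuse equally shaped feature maps with per-channel selection weights.

    streams: array of shape (num_streams, H, W, C)
    logits:  array of shape (num_streams, C), one score per stream and channel
             (hypothetical stand-in for the learned 1x1 convolutions)
    """
    weights = softmax(logits, axis=0)               # (num_streams, C), sums to 1 over streams
    weighted = streams * weights[:, None, None, :]  # broadcast the weights over H and W
    return weighted.sum(axis=0), weights            # fused map of shape (H, W, C)

rng = np.random.default_rng(0)
streams = rng.random((3, 4, 4, 8)).astype('float32')
logits = rng.normal(size=(3, 8))
fused, weights = skff_numpy(streams, logits)
print(fused.shape)           # (4, 4, 8)
print(weights.sum(axis=0))   # approximately 1.0 for every channel

# If stream 0 dominates every channel, the fused output is essentially stream 0
fused_sel, _ = skff_numpy(streams, np.array([[20.0] * 8, [0.0] * 8, [0.0] * 8]))
print(np.allclose(fused_sel, streams[0], atol=1e-6))  # True
```

Because the weights compete across streams, raising one stream's score for a channel automatically lowers the others, which is what lets SKFF "select" a resolution per channel.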
```python
import tensorflow as tf
from tensorflow.keras import layers

def selective_kernel_feature_fusion_keras(input_tensor_1, input_tensor_2, input_tensor_3):
    # All three inputs must share the same (height, width, channels) shape
    channels = input_tensor_1.shape[-1]
    # Combine features from three different scales
    fused_tensor = layers.Add()([input_tensor_1, input_tensor_2, input_tensor_3])
    # Generate a global descriptor using average pooling
    gap = layers.GlobalAveragePooling2D()(fused_tensor)
    gap = layers.Reshape((1, 1, channels))(gap)
    # Compress the descriptor
    shrink = layers.Conv2D(channels // 8, 1, padding='same', activation='relu')(gap)
    # Expand back to one attention vector (logits) per input
    logits = [layers.Conv2D(channels, 1)(shrink) for _ in range(3)]
    # Stack to (batch, 3, channels) and apply softmax across the three inputs,
    # so that for every channel the three weights sum to 1
    stacked = layers.Concatenate(axis=1)([layers.Reshape((1, channels))(t) for t in logits])
    weights = layers.Softmax(axis=1)(stacked)
    # Multiply each input by its attention weights and sum the results
    weighted = []
    for i, tensor in enumerate((input_tensor_1, input_tensor_2, input_tensor_3)):
        w = layers.Reshape((1, 1, channels))(weights[:, i])
        weighted.append(layers.Multiply()([tensor, w]))
    return layers.Add()(weighted)
```

Implement the Dual Attention Unit in Keras
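The Dual Attention Unit (DAU) is what I use to refine the features along two axes: channel attention decides which feature maps matter, and spatial attention decides which pixel locations matter. This helps the model suppress noise and focus on actual image content like textures. The version below is simplified and applies the two attentions one after the other; the original MIRNet runs them as parallel branches and merges them with a 1×1 convolution.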
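The channel path of the DAU is a squeeze-and-excitation style gate. The NumPy sketch below shows it on one feature map, with hypothetical random matrices `w_down` and `w_up` standing in for the two 1×1 convolutions (a 1×1 convolution applied to a 1×1 descriptor reduces to a matrix multiply, biases omitted here). It illustrates the math only, not the trained layer.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_attention_numpy(features, w_down, w_up):
    """features: (H, W, C); w_down: (C, C // 8); w_up: (C // 8, C)."""
    descriptor = features.mean(axis=(0, 1))        # squeeze: global average per channel, (C,)
    hidden = np.maximum(descriptor @ w_down, 0.0)  # ReLU bottleneck, (C // 8,)
    gates = sigmoid(hidden @ w_up)                 # one gate in (0, 1) per channel, (C,)
    return features * gates, gates                 # rescale every channel of the map

rng = np.random.default_rng(42)
features = rng.random((8, 8, 16))
w_down = rng.normal(scale=0.1, size=(16, 2))
w_up = rng.normal(scale=0.1, size=(2, 16))
scaled, gates = channel_attention_numpy(features, w_down, w_up)
print(scaled.shape)  # (8, 8, 16)
```

Since every gate lies between 0 and 1, the attention can only attenuate a channel, never amplify it; the residual `Add` at the end of `dual_attention_unit_keras` keeps the original signal available alongside the attended one.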
```python
def dual_attention_unit_keras(input_tensor):
    # Channel attention path
    ca = layers.GlobalAveragePooling2D()(input_tensor)
    ca = layers.Reshape((1, 1, input_tensor.shape[-1]))(ca)
    ca = layers.Conv2D(input_tensor.shape[-1] // 8, 1, activation='relu')(ca)
    ca = layers.Conv2D(input_tensor.shape[-1], 1, activation='sigmoid')(ca)
    ca_out = layers.Multiply()([input_tensor, ca])
    # Spatial attention path
    sa = layers.Conv2D(1, 1, activation='sigmoid')(ca_out)
    sa_out = layers.Multiply()([ca_out, sa])
    return layers.Add()([input_tensor, sa_out])
```

Design the Multi-Scale Residual Block in Keras
The Multi-Scale Residual Block (MRB) is the core structural unit: it processes features at multiple resolutions in parallel, fuses them, and adds the result back to its input. I find this approach much more effective than simple sequential layers for restoring shadows.
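Every stream in an MRB has to come back to exactly the input size, otherwise the `Add` inside SKFF fails. With `padding='same'`, a stride-`s` `Conv2D` produces `ceil(H / s)` rows and a stride-`s` `Conv2DTranspose` multiplies the size by `s`. The small helpers below (hypothetical utilities, not part of Keras) check which input sizes survive the round trip; the default strides `(1, 2, 4)` assume full-, half-, and quarter-resolution streams.

```python
import math

def round_trip_size(size, stride):
    # Conv2D(strides=stride, padding='same') -> ceil(size / stride)
    down = math.ceil(size / stride)
    # Conv2DTranspose(strides=stride, padding='same') -> down * stride
    return down * stride

def streams_align(size, strides=(1, 2, 4)):
    """True if every stream returns to `size`, so the streams can be added."""
    return all(round_trip_size(size, s) == size for s in strides)

for size in (256, 250, 254):
    print(size, [round_trip_size(size, s) for s in (1, 2, 4)], streams_align(size))
# 256 [256, 256, 256] True
# 250 [250, 250, 252] False
# 254 [254, 254, 256] False
```

In practice this means height and width should be divisible by the largest stride; 256 is divisible by 4, so the default input shape used in this tutorial is safe.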
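```python
def multi_scale_residual_block_keras(input_tensor, channels):
    # `channels` must match input_tensor's channel count for the residual Add,
    # and height/width must be divisible by 4 so every stream returns to the input size
    # Full-resolution stream
    stream_1 = layers.Conv2D(channels, 3, padding='same')(input_tensor)
    stream_1 = layers.LeakyReLU(0.2)(stream_1)
    # Half-resolution stream: downsample by 2, then upsample back
    stream_2 = layers.Conv2D(channels, 3, strides=2, padding='same')(input_tensor)
    stream_2 = layers.LeakyReLU(0.2)(stream_2)
    stream_2 = layers.Conv2DTranspose(channels, 3, strides=2, padding='same')(stream_2)
    # Quarter-resolution stream: downsample by 4, then upsample back
    stream_3 = layers.Conv2D(channels, 3, strides=4, padding='same')(input_tensor)
    stream_3 = layers.LeakyReLU(0.2)(stream_3)
    stream_3 = layers.Conv2DTranspose(channels, 4, strides=4, padding='same')(stream_3)
    # Fuse the three streams using SKFF
    fused = selective_kernel_feature_fusion_keras(stream_1, stream_2, stream_3)
    # Refine with Dual Attention
    refined = dual_attention_unit_keras(fused)
    return layers.Add()([input_tensor, refined])
```

Initialize the Full MIRNet Model in Keras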
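Now I will put all these blocks together into a complete Keras Model for end-to-end training. This compact model extracts 64 feature maps from the full-resolution input, passes them through two MRBs, projects them back to 3 channels, and adds the result to the input (a global residual). The network therefore learns a correction to the dark image rather than the whole image, and the output keeps the input's resolution.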
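As a sanity check when reading the `summary()` output, a `Conv2D` layer has `kernel_h × kernel_w × in_channels × filters` weights plus `filters` biases. The sketch below applies that formula to the first and last convolutions in `build_mirnet_keras`; the helper name `conv2d_params` is hypothetical.

```python
def conv2d_params(kernel_size, in_channels, filters, use_bias=True):
    # Weights: one kernel_size x kernel_size x in_channels filter per output channel
    weights = kernel_size * kernel_size * in_channels * filters
    # Biases: one per output channel (Conv2D uses a bias by default)
    return weights + (filters if use_bias else 0)

# Initial feature extraction: 3 RGB channels -> 64 feature maps with a 3x3 kernel
print(conv2d_params(3, 3, 64))   # 1792
# Final reconstruction: 64 feature maps -> 3 RGB channels with a 3x3 kernel
print(conv2d_params(3, 64, 3))   # 1731
```

The `Add` layers, including the global residual, contribute no parameters, so almost all of the model's weights live inside the MRBs.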
```python
def build_mirnet_keras(input_shape=(256, 256, 3)):
    inputs = layers.Input(shape=input_shape)
    # Initial feature extraction
    x1 = layers.Conv2D(64, 3, padding='same')(inputs)
    # Deep feature processing using MRBs
    x2 = multi_scale_residual_block_keras(x1, 64)
    x3 = multi_scale_residual_block_keras(x2, 64)
    # Final reconstruction layer
    x4 = layers.Conv2D(3, 3, padding='same')(x3)
    # Global residual: the network predicts a correction that is added to the input
    outputs = layers.Add()([inputs, x4])
    model = tf.keras.Model(inputs, outputs)
    return model

# Create the model instance
mirnet_model = build_mirnet_keras()
mirnet_model.summary()
```

Train with Charbonnier Loss in Keras
Standard MSE tends to produce blurry restorations, so I prefer Charbonnier loss for image restoration. It behaves like an L1 loss for large errors but, unlike L1, is smooth and differentiable at zero, which gives the optimizer well-behaved gradients when the prediction is already close to the target.
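The difference from L1 is easiest to see in the gradient. For a residual d = y_pred - y_true, the Charbonnier term sqrt(d² + ε²) has derivative d / sqrt(d² + ε²), which passes smoothly through 0 at d = 0, while the L1 (sub)gradient jumps from -1 to +1. The NumPy sketch below evaluates both with ε = 1e-3, the same value used in `charbonnier_loss_keras`.

```python
import numpy as np

EPSILON = 1e-3  # same value as in charbonnier_loss_keras

def charbonnier(d, eps=EPSILON):
    # Per-element Charbonnier penalty
    return np.sqrt(d ** 2 + eps ** 2)

def charbonnier_grad(d, eps=EPSILON):
    # d/dd sqrt(d^2 + eps^2) = d / sqrt(d^2 + eps^2)
    return d / np.sqrt(d ** 2 + eps ** 2)

residuals = np.array([-0.5, -1e-3, 0.0, 1e-3, 0.5])
print(np.abs(residuals))            # L1 loss per residual
print(charbonnier(residuals))       # close to |d| for large |d|, never below eps
print(np.sign(residuals))           # L1 (sub)gradient: -1 for negative, +1 for positive
print(charbonnier_grad(residuals))  # smooth: about -0.707, 0 and +0.707 around zero
```

Note that the loss never reaches exactly 0: a perfect prediction still scores ε per pixel, which is harmless because it is a constant offset.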
```python
def charbonnier_loss_keras(y_true, y_pred):
    # Smooth, differentiable approximation of the L1 loss
    epsilon = 1e-3
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(epsilon)))

def psnr_metric(y_true, y_pred):
    # tf.image.psnr requires max_val; pixel values are scaled to [0, 1]
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

# Compile the model with the Adam optimizer
mirnet_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=charbonnier_loss_keras,
    metrics=[psnr_metric]
)
```

Perform Image Inference in Keras
Once the model is trained, I use this function to pass in a dark image and retrieve the enhanced version. The model works on float values in [0, 1], so the output must be clipped, scaled back to the 0 to 255 range, and converted to 8-bit integers before the photo can be saved.
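Rounding matters when converting back, because `astype('uint8')` truncates: a prediction of 0.99999, which scales to about 254.997, would be stored as 254 instead of 255. The NumPy sketch below isolates the scale, round, clip, and convert steps, using hypothetical helper names `to_model_input` and `to_uint8_image`, and checks that every 8-bit value survives a round trip.

```python
import numpy as np

def to_model_input(pixels_uint8):
    # uint8 (H, W, 3) in [0, 255] -> float32 (1, H, W, 3) in [0, 1]
    scaled = pixels_uint8.astype('float32') / 255.0
    return np.expand_dims(scaled, axis=0)

def to_uint8_image(prediction):
    # Float prediction (H, W, 3), possibly outside [0, 1] because of the global residual
    # -> uint8: scale, round to the nearest integer, clip, then convert
    return np.clip(np.rint(prediction * 255.0), 0, 255).astype('uint8')

# Every possible 8-bit value survives the round trip
all_values = np.arange(256, dtype='uint8').reshape(16, 16, 1).repeat(3, axis=2)
restored = to_uint8_image(to_model_input(all_values)[0])
print(np.array_equal(restored, all_values))  # True

# Out-of-range predictions are clipped instead of wrapping around
print(to_uint8_image(np.array([[[-0.2, 0.5, 1.3]]])))  # [[[  0 128 255]]]
```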
```python
import numpy as np
from PIL import Image

def run_inference_keras(model, image_path):
    # Load the image, force 3 channels (drops alpha, expands grayscale)
    # and resize to the model's fixed 256x256 input size
    img = Image.open(image_path).convert('RGB').resize((256, 256))
    img_array = np.array(img).astype('float32') / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # add the batch dimension
    # Predict and post-process: scale to [0, 255], round, clip, convert to 8-bit
    prediction = model.predict(img_array)[0]
    prediction = np.clip(np.rint(prediction * 255.0), 0, 255).astype('uint8')
    return Image.fromarray(prediction)

# Example usage:
# enhanced_photo = run_inference_keras(mirnet_model, 'dark_chicago_street.jpg')
# enhanced_photo.show()
```

You can refer to the screenshot below to see the output.

I hope this tutorial helps you build your own image enhancement tools. MIRNet is a fantastic piece of tech that really pushes what’s possible with deep learning and Keras.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in various Python libraries, such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and other countries.