How to Enhance Low-Light Images Using MIRNet in Keras

Enhancing low-light images has always been a personal fascination for me. It’s like magic when you can take a dark, grainy photo from a late-night street in NYC and make it look professional.

I’ve spent over four years developing with Keras, and MIRNet is hands-down one of the most powerful architectures I’ve used for this. It handles multi-scale features beautifully without losing those sharp details we all love.

In this tutorial, I will show you how to implement low-light image enhancement using MIRNet in Keras. We’ll cover everything from building the blocks to running inference on your own dark photos.

Build Selective Kernel Feature Fusion in Keras

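Selective Kernel Feature Fusion (SKFF) is the part of MIRNet that decides how to mix features from different resolutions. It sums the streams, summarizes the result with global average pooling, and learns a per-channel weight for each stream. I use it to make sure the model focuses on the most relevant scale for each feature.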

import tensorflow as tf
from tensorflow.keras import layers

def selective_kernel_feature_fusion_keras(input_tensor_1, input_tensor_2, input_tensor_3):
    # All three inputs must share the same (height, width, channels)
    channels = input_tensor_1.shape[-1]

    # Fuse: combine features from the three scales by element-wise sum
    fused_tensor = layers.Add()([input_tensor_1, input_tensor_2, input_tensor_3])

    # Generate a global descriptor using average pooling: (batch, 1, 1, channels)
    gap = layers.GlobalAveragePooling2D()(fused_tensor)
    gap = layers.Reshape((1, 1, channels))(gap)

    # Compress the descriptor (channel reduction ratio of 8)
    shrink = layers.Conv2D(channels // 8, 1, padding='same')(gap)
    shrink = layers.Activation('relu')(shrink)

    # Select: one attention logit vector per input stream
    logits_1 = layers.Conv2D(channels, 1)(shrink)
    logits_2 = layers.Conv2D(channels, 1)(shrink)
    logits_3 = layers.Conv2D(channels, 1)(shrink)

    # Softmax across the three streams (not across channels), so the
    # three weights for each channel sum to 1
    logits = layers.Concatenate(axis=1)([logits_1, logits_2, logits_3])  # (batch, 3, 1, channels)
    weights = layers.Softmax(axis=1)(logits)
    attn_1 = weights[:, 0:1]
    attn_2 = weights[:, 1:2]
    attn_3 = weights[:, 2:3]

    # Multiply inputs by their respective attention weights and sum
    out_1 = layers.Multiply()([input_tensor_1, attn_1])
    out_2 = layers.Multiply()([input_tensor_2, attn_2])
    out_3 = layers.Multiply()([input_tensor_3, attn_3])

    return layers.Add()([out_1, out_2, out_3])
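In MIRNet's SKFF, the softmax is taken across the streams for each channel, so the fused output is a per-channel weighted average of the streams rather than a plain sum. The small NumPy sketch below illustrates that with made-up attention logits for three streams and four channels (the numbers and the toy feature maps are hypothetical, chosen only for illustration).

```python
import numpy as np


def softmax(x, axis):
    # Subtract the max for numerical stability before exponentiating
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


# Hypothetical attention logits: 3 streams x 4 channels
logits = np.array([
    [2.0, 0.0, 0.0, 1.0],   # stream 1 (full resolution)
    [0.0, 2.0, 0.0, 1.0],   # stream 2 (1/2 resolution)
    [0.0, 0.0, 2.0, 1.0],   # stream 3 (1/4 resolution)
])

# Normalize across streams (axis 0), separately for every channel
weights = softmax(logits, axis=0)

# Three toy feature maps of shape (H, W, C) = (2, 2, 4) with constant values
streams = [np.full((2, 2, 4), v) for v in (1.0, 2.0, 3.0)]

# Per-channel weighted sum; each weights[i] of shape (4,) broadcasts over (2, 2, 4)
fused = sum(w * s for w, s in zip(weights, streams))
```

In channel 0 stream 1 has the largest logit, so the fused value leans toward 1.0; in channel 3 all logits are equal, so the three streams are averaged evenly.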

Implement the Dual Attention Unit in Keras

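I use the Dual Attention Unit (DAU) to refine features using both channel-wise and spatial information. It helps the model suppress less useful responses, such as noise, and emphasize actual image content like textures. The version below is simplified: it applies channel attention and then spatial attention in sequence, whereas the MIRNet paper runs the two branches in parallel and merges them.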

def dual_attention_unit_keras(input_tensor):
    # Channel attention path
    ca = layers.GlobalAveragePooling2D()(input_tensor)
    ca = layers.Reshape((1, 1, input_tensor.shape[-1]))(ca)
    ca = layers.Conv2D(input_tensor.shape[-1] // 8, 1, activation='relu')(ca)
    ca = layers.Conv2D(input_tensor.shape[-1], 1, activation='sigmoid')(ca)
    ca_out = layers.Multiply()([input_tensor, ca])
    
    # Spatial attention path
    sa = layers.Conv2D(1, 1, activation='sigmoid')(ca_out)
    sa_out = layers.Multiply()([ca_out, sa])
    
    return layers.Add()([input_tensor, sa_out])
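For comparison, here is a sketch of a DAU closer to the MIRNet paper's description, where the spatial and channel attention branches run in parallel on the same features and are then merged with a 1x1 convolution and a residual connection. Assumptions: ReLU stands in for the paper's PReLU, the channel reduction ratio is 8 as in the code above, and the function names (`spatial_attention_branch`, `channel_attention_branch`, `parallel_dual_attention_unit`) are hypothetical. The spatial branch pools along the channel axis with both average and max pooling, which gives the 1-channel attention map more information than a single 1x1 convolution.

```python
import tensorflow as tf
from tensorflow.keras import layers


def spatial_attention_branch(x):
    # Pool along the channel axis: (B, H, W, C) -> two (B, H, W, 1) maps
    avg_pool = layers.Lambda(lambda t: tf.reduce_mean(t, axis=-1, keepdims=True))(x)
    max_pool = layers.Lambda(lambda t: tf.reduce_max(t, axis=-1, keepdims=True))(x)
    pooled = layers.Concatenate(axis=-1)([avg_pool, max_pool])  # (B, H, W, 2)
    # A conv + sigmoid gives one weight in (0, 1) per pixel, broadcast over channels
    attn = layers.Conv2D(1, 1, activation='sigmoid')(pooled)
    return layers.Multiply()([x, attn])


def channel_attention_branch(x, reduction=8):
    channels = x.shape[-1]
    # Squeeze: global average pooling -> (B, 1, 1, C)
    squeeze = layers.GlobalAveragePooling2D()(x)
    squeeze = layers.Reshape((1, 1, channels))(squeeze)
    # Excite: bottleneck convs + sigmoid give one weight per channel
    excite = layers.Conv2D(channels // reduction, 1, activation='relu')(squeeze)
    excite = layers.Conv2D(channels, 1, activation='sigmoid')(excite)
    return layers.Multiply()([x, excite])


def parallel_dual_attention_unit(x):
    channels = x.shape[-1]
    # Initial feature transform (conv-activation-conv)
    features = layers.Conv2D(channels, 3, padding='same', activation='relu')(x)
    features = layers.Conv2D(channels, 3, padding='same')(features)
    # Both attention branches see the same features, in parallel
    sa = spatial_attention_branch(features)
    ca = channel_attention_branch(features)
    # Concatenate, project back to `channels` with a 1x1 conv, then add the input
    merged = layers.Concatenate(axis=-1)([sa, ca])
    merged = layers.Conv2D(channels, 1)(merged)
    return layers.Add()([x, merged])
```

The output has the same shape as the input, so this function could stand in for `dual_attention_unit_keras` inside the multi-scale block.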

Design the Multi-Scale Residual Block in Keras

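The Multi-Scale Residual Block (MRB) is the core structural unit that processes features at several resolutions in parallel. I find this approach much more effective than simple sequential layers for restoring shadows. The block below is a simplified version of the one in the MIRNet paper: each stream is very shallow, the streams are fused once with SKFF, and the result is refined by a single DAU.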

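def multi_scale_residual_block_keras(input_tensor, channels):
    # Input height and width must be divisible by 4 so every stream
    # returns to exactly the input size after upsampling

    # Stream 1: full resolution
    stream_1 = layers.Conv2D(channels, 3, padding='same')(input_tensor)
    stream_1 = layers.LeakyReLU(0.2)(stream_1)

    # Stream 2: half resolution, then upsampled back
    stream_2 = layers.Conv2D(channels, 3, strides=2, padding='same')(input_tensor)
    stream_2 = layers.LeakyReLU(0.2)(stream_2)
    stream_2 = layers.Conv2DTranspose(channels, 3, strides=2, padding='same')(stream_2)

    # Stream 3: quarter resolution, then upsampled back
    # (kernel size 4 matches stride 4, so no input or output pixel is skipped)
    stream_3 = layers.Conv2D(channels, 4, strides=4, padding='same')(input_tensor)
    stream_3 = layers.LeakyReLU(0.2)(stream_3)
    stream_3 = layers.Conv2DTranspose(channels, 4, strides=4, padding='same')(stream_3)

    # Fuse the three distinct scales using SKFF
    fused = selective_kernel_feature_fusion_keras(stream_1, stream_2, stream_3)

    # Refine with Dual Attention
    refined = dual_attention_unit_keras(fused)

    # Residual connection (input_tensor must already have `channels` channels)
    return layers.Add()([input_tensor, refined])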
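Strided downsampling followed by transposed-convolution upsampling only returns to the original size when the height and width divide evenly by the stride. With `padding='same'`, Keras `Conv2D` produces `ceil(size / stride)` and `Conv2DTranspose` (without `output_padding`) produces `size * stride`. This plain-Python sketch applies those two rules so you can check an input size before building the model; the helper names are hypothetical.

```python
import math


def conv_same_output(size, stride):
    # Conv2D with padding='same' yields ceil(size / stride)
    return math.ceil(size / stride)


def transpose_same_output(size, stride):
    # Conv2DTranspose with padding='same' and no output_padding yields size * stride
    return size * stride


def stream_round_trip(size, stride):
    # Spatial size after downsampling and then upsampling by the same stride
    return transpose_same_output(conv_same_output(size, stride), stride)


for size in (256, 250, 30):
    print(size, [stream_round_trip(size, s) for s in (2, 4)])
```

A stride-2 stream needs even sizes and a stride-4 stream needs multiples of 4; otherwise the `Add` inside SKFF fails because the streams no longer have matching shapes.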

Initialize the Full MIRNet Model in Keras

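Now I will put all these blocks together into a complete Keras Model for end-to-end training. The network takes a high-resolution input and produces an enhanced output at the same resolution: the final convolution predicts a 3-channel correction that is added back to the input image (a global residual connection).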

def build_mirnet_keras(input_shape=(256, 256, 3)):
    inputs = layers.Input(shape=input_shape)
    
    # Initial feature extraction
    x1 = layers.Conv2D(64, 3, padding='same')(inputs)
    
    # Deep feature processing using MRBs
    x2 = multi_scale_residual_block_keras(x1, 64)
    x3 = multi_scale_residual_block_keras(x2, 64)
    
    # Final reconstruction layer
    x4 = layers.Conv2D(3, 3, padding='same')(x3)
    outputs = layers.Add()([inputs, x4])
    
    model = tf.keras.Model(inputs, outputs)
    return model

# Create the model instance
mirnet_model = build_mirnet_keras()
mirnet_model.summary()

Train with Charbonnier Loss in Keras

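Standard MSE often makes images look blurry, so I prefer using Charbonnier Loss for image restoration. It behaves like an L1 loss for errors much larger than epsilon, but unlike L1 it is smooth at zero, so its gradient does not jump abruptly when the error changes sign. In practice this tends to make optimization more stable.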
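A quick NumPy comparison makes the difference concrete. For an error e, the Charbonnier penalty is sqrt(e² + ε²) and its derivative is e / sqrt(e² + ε²). The sketch below uses ε = 1e-3, the same value as the Keras loss, to show the penalty tracking |e| for large errors while the gradient passes smoothly through 0, where the L1 gradient would jump from -1 to +1. The helper names are illustrative only.

```python
import numpy as np

EPS = 1e-3  # same epsilon as the Keras loss in this section


def charbonnier(err, eps=EPS):
    # sqrt(err^2 + eps^2): close to |err| when |err| >> eps, smooth at err = 0
    return np.sqrt(err ** 2 + eps ** 2)


def charbonnier_grad(err, eps=EPS):
    # Derivative with respect to err: err / sqrt(err^2 + eps^2), always in (-1, 1)
    return err / np.sqrt(err ** 2 + eps ** 2)


errors = np.array([-0.5, -1e-4, 0.0, 1e-4, 0.5])
print("L1:         ", np.abs(errors))
print("Charbonnier:", charbonnier(errors))
print("gradient:   ", charbonnier_grad(errors))
```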

def charbonnier_loss_keras(y_true, y_pred):
    # Differentiable approximation of L1 loss
    epsilon = 1e-3
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(epsilon)))

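# PSNR on images scaled to [0, 1]; tf.image.psnr requires max_val, so wrap it
def psnr_metric(y_true, y_pred):
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

# Compile the model with the Adam optimizer
mirnet_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=charbonnier_loss_keras,
    metrics=[psnr_metric]
)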
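The model is compiled here, but the training data still has to come from somewhere. One way to feed it is a `tf.data` pipeline over paired low-light and normal-light images, sketched below. Assumptions: you have two aligned lists of file paths (for example from a paired dataset such as LOL), each pair has identical dimensions of at least `crop_size` pixels per side, and `crop_size` matches the model's 256×256 input. The function names are hypothetical. Both images are scaled to [0, 1], matching the `/ 255.0` scaling used at inference time, and the same random crop is taken from both by stacking them before cropping.

```python
import tensorflow as tf


def load_pair(low_path, high_path):
    # Read both files and decode them as 3-channel uint8 images
    low = tf.io.decode_image(tf.io.read_file(low_path), channels=3, expand_animations=False)
    high = tf.io.decode_image(tf.io.read_file(high_path), channels=3, expand_animations=False)
    # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1]
    low = tf.image.convert_image_dtype(low, tf.float32)
    high = tf.image.convert_image_dtype(high, tf.float32)
    return low, high


def random_crop_pair(low, high, crop_size):
    # Stack so that one random_crop call picks the same window for both images
    stacked = tf.stack([low, high], axis=0)
    cropped = tf.image.random_crop(stacked, size=[2, crop_size, crop_size, 3])
    return cropped[0], cropped[1]


def make_paired_dataset(low_paths, high_paths, crop_size=256, batch_size=4):
    ds = tf.data.Dataset.from_tensor_slices((low_paths, high_paths))
    ds = ds.shuffle(len(low_paths))
    ds = ds.map(load_pair, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(lambda lo, hi: random_crop_pair(lo, hi, crop_size),
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

You would then train with something like `mirnet_model.fit(train_ds, epochs=num_epochs)`, choosing `num_epochs` yourself.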

Perform Image Inference in Keras

Once the model is trained, I use this function to pass in a dark image and retrieve the enhanced version. The model works on pixel values scaled to [0, 1], so it is important to rescale the prediction back to [0, 255] and clip it before saving the photo.

import numpy as np
from PIL import Image

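def run_inference_keras(model, image_path):
    # Load the image, force 3 RGB channels (drops alpha, expands grayscale),
    # and resize to the model's 256x256 input size
    img = Image.open(image_path).convert('RGB')
    original_size = img.size
    img_array = np.array(img.resize((256, 256))).astype('float32') / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # add a batch dimension

    # Predict, rescale from [0, 1] to [0, 255], and convert to 8-bit pixels
    prediction = model.predict(img_array)[0]
    prediction = np.clip(prediction * 255.0, 0, 255).astype('uint8')

    # Resize back to the original dimensions
    return Image.fromarray(prediction).resize(original_size)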

# Example usage:
# enhanced_photo = run_inference_keras(mirnet_model, 'dark_chicago_street.jpg')
# enhanced_photo.show()
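Resizing to 256×256 distorts the aspect ratio and discards resolution. If you build the model with `input_shape=(None, None, 3)` instead (an assumption; this tutorial uses a fixed size), an alternative is to reflect-pad the image so its height and width are multiples of 4, run the model, and crop the result back. This NumPy sketch covers only the padding and cropping; the helper names are hypothetical.

```python
import numpy as np


def pad_to_multiple(img, multiple=4):
    # img: (H, W, C) array. Reflect-pad bottom/right so H and W become multiples of `multiple`
    h, w = img.shape[:2]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect')
    return padded, (h, w)


def crop_to_original(img, original_hw):
    # Undo pad_to_multiple by cropping back to the original height and width
    h, w = original_hw
    return img[:h, :w]
```

Given a single (H, W, 3) float image `x`, usage would be `padded, hw = pad_to_multiple(x)` followed by `enhanced = crop_to_original(model.predict(padded[None])[0], hw)`.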

You can refer to the screenshot below to see the output.

Enhance Low-Light Images Using MIRNet in Keras

I hope this tutorial helps you build your own image enhancement tools. MIRNet is a fantastic piece of tech that really pushes what’s possible with deep learning and Keras.
