Enhance Dull Photos Using Zero-DCE in Keras

In my years of developing computer vision apps for startups in New York, I have often struggled with dark, grainy photos that ruin the user experience. Traditional filters usually make things look worse, so I turned to Zero-Reference Deep Curve Estimation (Zero-DCE) to fix these lighting issues properly.

Zero-DCE is a game-changer because it doesn’t need “perfect” photos to learn; it trains using only the dark images you already have. I found that it estimates high-order curves to adjust pixel dynamic range, making it incredibly fast and efficient for real-time mobile apps.
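Before building the network, it helps to see what one of these curves actually does to a single pixel. Here is a minimal sketch of the paper's quadratic light-enhancement curve, LE(x) = x + α·x·(1 − x), applied to one normalized value (the helper name is my own, not from the paper):

```python
def quadratic_curve(pixel, alpha):
    # Zero-DCE light-enhancement curve: LE(x) = x + alpha * x * (1 - x)
    # For alpha in [-1, 1] and pixel in [0, 1], the result stays in [0, 1]
    return pixel + alpha * pixel * (1.0 - pixel)

# A dark pixel brightens noticeably, while values near 0 and 1 barely move
dark_out = quadratic_curve(0.2, 0.8)    # ~0.328
bright_out = quadratic_curve(0.95, 0.8) # ~0.988
```

Stacking eight of these curves, one per iteration, with a separate alpha map per pixel and channel, is what gives the method its dynamic range without ever clipping out of the valid pixel range.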

Build the DCE-Net Architecture in Python Keras

I usually start by building the DCE-Net, which is a lightweight CNN that predicts the best enhancement curves for each pixel. It uses seven convolutional layers with symmetrical skip connections to ensure the spatial details of your images stay sharp.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def build_dce_net_keras():
    # Define the input layer for a flexible image size
    input_img = keras.Input(shape=[None, None, 3])
    
    # Layer 1 to 4: Extract features through standard convolutions
    conv1 = layers.Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(input_img)
    conv2 = layers.Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(conv1)
    conv3 = layers.Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(conv2)
    conv4 = layers.Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(conv3)
    
    # Layer 5 to 7: Use skip connections to concatenate previous features
    int_con1 = layers.Concatenate(axis=-1)([conv4, conv3])
    conv5 = layers.Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(int_con1)
    int_con2 = layers.Concatenate(axis=-1)([conv5, conv2])
    conv6 = layers.Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(int_con2)
    int_con3 = layers.Concatenate(axis=-1)([conv6, conv1])
    
    # Final layer predicts 24 curve parameters (3 channels x 8 iterations)
    output = layers.Conv2D(24, (3, 3), strides=(1, 1), activation="tanh", padding="same")(int_con3)
    return keras.Model(inputs=input_img, outputs=output)

# Initialize the model
model = build_dce_net_keras()
model.summary()

Implement Custom Non-Reference Loss Functions in Keras

Since we don’t have “ground truth” images, the original paper guides the model toward a well-lit result with four non-reference losses; I implement three of them below. These losses ensure the colors look natural, the exposure is balanced, and the transition between neighboring pixels remains smooth.

def color_constancy_loss_keras(x):
    # This ensures the RGB channels stay balanced based on the Gray-world hypothesis
    mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    mr, mg, mb = mean_rgb[:, :, :, 0], mean_rgb[:, :, :, 1], mean_rgb[:, :, :, 2]
    d_rg = tf.square(mr - mg)
    d_rb = tf.square(mr - mb)
    d_gb = tf.square(mb - mg)
    return tf.sqrt(tf.square(d_rg) + tf.square(d_rb) + tf.square(d_gb))

def exposure_loss_keras(x, mean_val=0.6):
    # I set the target brightness to 0.6 to keep the image from looking washed out
    x = tf.reduce_mean(x, axis=3, keepdims=True)
    mean = tf.nn.avg_pool2d(x, ksize=16, strides=16, padding="VALID")
    return tf.reduce_mean(tf.square(mean - mean_val))

def illumination_smoothness_loss_keras(x):
    # This prevents artifacts by keeping the estimated curves smooth across the image
    batch_size = tf.shape(x)[0]
    h_x = tf.shape(x)[1]
    w_x = tf.shape(x)[2]
    h_tv = tf.reduce_sum(tf.square((x[:, 1:, :, :] - x[:, : h_x - 1, :, :])))
    w_tv = tf.reduce_sum(tf.square((x[:, :, 1:, :] - x[:, :, : w_x - 1, :])))
    return (h_tv + w_tv) / tf.cast(batch_size, tf.float32)
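The paper's fourth loss, spatial consistency, preserves local contrast between the dark input and the enhanced output. Below is a simplified sketch of that idea, assuming 4×4 average-pooled regions and signed neighbor differences rather than the paper's exact four-neighbor kernel formulation; the function name is my own:

```python
import tensorflow as tf

def spatial_consistency_loss_keras(enhanced, original):
    # Average RGB down to a single luminance-like channel
    org = tf.reduce_mean(original, axis=3, keepdims=True)
    enh = tf.reduce_mean(enhanced, axis=3, keepdims=True)
    # Pool into 4x4 local regions, as in the paper
    org_pool = tf.nn.avg_pool2d(org, ksize=4, strides=4, padding="VALID")
    enh_pool = tf.nn.avg_pool2d(enh, ksize=4, strides=4, padding="VALID")
    # Differences between vertically and horizontally adjacent regions
    d_org_h = org_pool[:, 1:, :, :] - org_pool[:, :-1, :, :]
    d_org_w = org_pool[:, :, 1:, :] - org_pool[:, :, :-1, :]
    d_enh_h = enh_pool[:, 1:, :, :] - enh_pool[:, :-1, :, :]
    d_enh_w = enh_pool[:, :, 1:, :] - enh_pool[:, :, :-1, :]
    # Penalize changes in local contrast introduced by the enhancement
    return tf.reduce_mean(tf.square(d_enh_h - d_org_h)) + tf.reduce_mean(
        tf.square(d_enh_w - d_org_w)
    )
```

If you add this term to the training objective, note that it takes both the enhanced image and the original input, unlike the three losses above.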

Apply Iterative Enhancement Curves in Python Keras

The core magic happens when we apply the predicted curve parameters to the original dark image over multiple iterations. In my experience, 8 iterations provide the perfect balance between high-quality brightening and processing speed.

def enhance_image_keras(image, curve_params):
    # Split the 24 channels into 8 sets of 3-channel curve parameters (one set per iteration)
    iteration_results = [image]
    for i in range(8):
        alpha = curve_params[:, :, :, i * 3 : (i + 1) * 3]
        last_img = iteration_results[-1]
        # Quadratic curve formula: I_next = I + alpha * I * (1 - I)
        enhanced = last_img + alpha * last_img * (1.0 - last_img)
        iteration_results.append(enhanced)
    return iteration_results[-1]

# Example usage with a dummy dark image tensor
dummy_img = tf.random.uniform((1, 256, 256, 3))
params = model(dummy_img)
bright_img = enhance_image_keras(dummy_img, params)
print("Enhanced Image Shape:", bright_img.shape)
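In practice the output of the curve iterations can drift slightly outside [0, 1], so before displaying or saving a result I find it safer to clip and convert to 8-bit. A small helper for that (my own addition, not part of the original pipeline):

```python
import tensorflow as tf

def postprocess_for_display(enhanced):
    # Clip to the valid [0, 1] range, then scale to 8-bit pixel values
    clipped = tf.clip_by_value(enhanced, 0.0, 1.0)
    return tf.cast(clipped * 255.0, tf.uint8)
```

The resulting uint8 tensor can be passed directly to image-writing utilities such as `tf.io.encode_png`.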

Train the Zero-DCE Model Using Keras Custom Loops

I typically wrap everything into a custom keras.Model subclass to make the training process as simple as calling .fit(). This approach lets me calculate all the complex losses in one go and update the gradients efficiently during each training step.

class ZeroDCEModel(keras.Model):
    def __init__(self, dce_net):
        super(ZeroDCEModel, self).__init__()
        self.dce_net = dce_net

    def train_step(self, data):
        with tf.GradientTape() as tape:
            output = self.dce_net(data)
            enhanced_img = enhance_image_keras(data, output)
            
            # Combine all the non-reference losses
            # Reduce the per-image color loss to a scalar so total_loss stays scalar
            loss_col = tf.reduce_mean(color_constancy_loss_keras(enhanced_img))
            loss_exp = exposure_loss_keras(enhanced_img)
            loss_tv = illumination_smoothness_loss_keras(output)
            total_loss = loss_col + loss_exp + 10 * loss_tv
            
        gradients = tape.gradient(total_loss, self.dce_net.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.dce_net.trainable_weights))
        return {"total_loss": total_loss, "exp_loss": loss_exp}

# Prepare for training
zero_dce = ZeroDCEModel(model)
zero_dce.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4))
# zero_dce.fit(dataset, epochs=50) # Assuming 'dataset' contains dark images
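For completeness, here is one way I might build that `dataset` with `tf.data`. The image size, batch size, and file paths are assumptions — adjust them to wherever your low-light photos live:

```python
import tensorflow as tf

IMAGE_SIZE = 256
BATCH_SIZE = 16

def load_dark_image(path):
    # Decode a low-light photo and scale pixels to [0, 1]
    raw = tf.io.read_file(path)
    img = tf.image.decode_png(raw, channels=3)
    img = tf.image.resize(img, [IMAGE_SIZE, IMAGE_SIZE])
    return img / 255.0

def build_dark_image_dataset(image_paths):
    ds = tf.data.Dataset.from_tensor_slices(image_paths)
    ds = ds.map(load_dark_image, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

# e.g. dataset = build_dark_image_dataset(sorted(glob.glob("path/to/dark/*.png")))
```

Because Zero-DCE is zero-reference, the pipeline yields only the dark images themselves — no paired well-lit targets are needed.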

You can see the output in the screenshot below.


Implementing Zero-DCE in Keras has saved me countless hours of manual image processing in my recent projects. I hope this guide helps you brighten your low-light images with minimal effort and professional results.
