Knowledge Distillation in Keras

I have spent a significant amount of time building complex deep learning models that perform brilliantly but are far too heavy for mobile devices.

In my experience, Knowledge Distillation is the most effective way to shrink a massive “Teacher” model into a compact “Student” model while keeping the accuracy high.

In this tutorial, I will show you exactly how to implement Knowledge Distillation in Keras using a real-world scenario: analyzing consumer credit card transactions.

What is Keras Knowledge Distillation?

Knowledge Distillation is a technique where a small model learns to mimic the behavior of a large, pre-trained model.

I find this particularly useful when I need to deploy high-performing models on edge devices with limited memory.

Step 1: Set up your Keras environment for Distillation

Before we start, we need to ensure our environment is ready with the necessary libraries for building our neural networks.

I recommend a recent TensorFlow 2.x release, since the custom train_step pattern we build later relies on the Keras API that ships with it.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Verify the version for compatibility
print(tf.__version__)

Step 2: Prepare a Financial Dataset for Keras Training

For this example, we will simulate a dataset representing credit card transaction features, common in US-based financial fraud detection.

I prefer using synthetic data for tutorials so you can run the code immediately without downloading massive CSV files; just keep in mind that random labels only demonstrate the mechanics, so don't expect meaningful accuracy numbers here.

# Simulating 10,000 credit card transactions with 20 features each
num_samples = 10000
x_train = np.random.random((num_samples, 20))
y_train = np.random.randint(2, size=(num_samples, 1))

x_test = np.random.random((2000, 20))
y_test = np.random.randint(2, size=(2000, 1))

Step 3: Create the Teacher Model in Keras

The Teacher model is usually a deep network that has high capacity and learns the intricate patterns within the financial data.

In my projects, I use a multi-layer dense network to ensure the Teacher is “smart” enough to guide the smaller model.

# Building a heavy Teacher model
teacher = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        layers.Dense(256, activation="relu"),
        layers.Dense(128, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(1),
    ],
    name="teacher",
)

teacher.summary()
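
One important detail: the teacher must be trained on the labeled data before it can guide anyone, otherwise its "soft targets" are just noise. Here is a minimal training run on our synthetic transactions (five epochs is an arbitrary choice for this demo):

# Train the teacher on the labeled transactions before distilling from it
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
teacher.fit(x_train, y_train, epochs=5)

Note that the teacher outputs a raw logit (the final Dense(1) has no activation), which is why the loss uses from_logits=True.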

Step 4: Build the Student Model in Keras

The Student model is significantly smaller, designed for speed and low-latency inference during real-time transaction processing.

I usually design the student with fewer layers and neurons to see how much “knowledge” it can actually absorb.

# Building a lightweight Student model
student = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        layers.Dense(32, activation="relu"),
        layers.Dense(16, activation="relu"),
        layers.Dense(1),
    ],
    name="student",
)

student.summary()

Step 5: Define the Distiller Class in Keras

To perform distillation, we subclass keras.Model and override train_step so the student learns from both the ground-truth labels and the “soft targets” produced by the teacher.

I have found that manually calculating the loss between the teacher’s predictions and the student’s predictions gives the best results.

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        x, y = data
        # Forward pass of the frozen teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of the student
            student_predictions = self.student(x, training=True)
            # Hard-label loss against the ground truth
            student_loss = self.student_loss_fn(y, student_predictions)
            # Both models output a single logit, so soften with a
            # temperature-scaled sigmoid and build two-class distributions
            # (softmax over a single unit would always return 1.0)
            teacher_probs = tf.sigmoid(teacher_predictions / self.temperature)
            student_probs = tf.sigmoid(student_predictions / self.temperature)
            distillation_loss = self.distillation_loss_fn(
                tf.concat([1.0 - teacher_probs, teacher_probs], axis=1),
                tf.concat([1.0 - student_probs, student_probs], axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients for the student only and update its weights
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # BinaryAccuracy expects probabilities, so squash the logits first
        self.compiled_metrics.update_state(y, tf.sigmoid(student_predictions))
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        # Evaluate the student on its own predictions
        student_predictions = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, student_predictions)
        self.compiled_metrics.update_state(y, tf.sigmoid(student_predictions))
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

Step 6: Initialize the Keras Distillation Process

Now that our Distiller class is ready, we need to initialize it with our models and choose a temperature for the “softening” process.

I often experiment with the temperature value; a higher temperature creates “softer” probability distributions for the student to learn from.

# Initialize and compile the distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.BinaryAccuracy()],
    student_loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)
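
To build some intuition for what the temperature actually does, here is a tiny illustration (purely for intuition, not part of the training pipeline): the same logit squashed through a temperature-scaled sigmoid at increasing values of T.

# How temperature "softens" a prediction: the same logit at different T values
logit = tf.constant([[4.0]])
for T in (1, 3, 10):
    prob = tf.sigmoid(logit / T).numpy()[0, 0]
    print(f"T={T}: probability = {prob:.3f}")
# T=1 -> ~0.982, T=3 -> ~0.791, T=10 -> ~0.599; a higher T pulls the
# probability toward 0.5, exposing more of the teacher's uncertainty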

Step 7: Train the Student via Keras Distillation

This is where the magic happens: the student trains not just on the raw labels, but on the rich output of the teacher.

I have noticed that training a student this way often yields better results than training it on the data alone.

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=10)
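
If you want to verify that claim on your own data, a baseline worth running (this comparison is my addition, not a required step) is an identical student trained from scratch on the hard labels only:

# Baseline: an identical student trained on the hard labels alone
student_scratch = keras.models.clone_model(student)
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
student_scratch.fit(x_train, y_train, epochs=10)
student_scratch.evaluate(x_test, y_test)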

Step 8: Evaluate Keras Model Performance

After distillation, it is critical to compare the student’s performance against the original teacher’s accuracy on the test set.

I use the standard evaluation method to ensure the student model meets the production requirements for our financial app.

# Evaluate the student on the test data
distiller.evaluate(x_test, y_test)
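
Since the whole point is shipping the compact student to resource-constrained devices, here is a sketch of exporting just the student with the TensorFlow Lite converter (the file name is only an example):

# Export only the compact student for on-device inference
converter = tf.lite.TFLiteConverter.from_keras_model(distiller.student)
tflite_model = converter.convert()

with open("student_fraud_model.tflite", "wb") as f:
    f.write(tflite_model)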

Method 2: Distillation using Keras Feature Mapping

Another approach I frequently use is matching the intermediate feature maps of the teacher and student models.

This helps the student learn the internal “thinking process” of the teacher rather than just the final output.

# Example of matching intermediate layer outputs
def feature_loss(teacher_features, student_features):
    return tf.reduce_mean(tf.square(teacher_features - student_features))

# In a real scenario, you would extract intermediate layers using Keras Functional API
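
Here is a minimal sketch of how that extraction could look with the Functional API (assuming TF 2.x, where Sequential models built from a keras.Input expose their intermediate layer outputs; the layer indices and the 128-unit projection are illustrative choices, not fixed requirements):

# Expose the teacher's second hidden layer (128 units) and the
# student's first hidden layer (32 units) as feature extractors
teacher_features = keras.Model(teacher.inputs, teacher.layers[1].output)
student_features = keras.Model(student.inputs, student.layers[0].output)

# The widths differ, so a trainable projection maps the student's features
# into the teacher's feature space before comparing them
projection = layers.Dense(128, use_bias=False)

# Compare the two feature maps on a small batch
batch = x_train[:32]
t_feat = teacher_features(batch, training=False)
s_feat = projection(student_features(batch, training=True))
print("Feature-matching loss:", feature_loss(t_feat, s_feat).numpy())

In a full implementation you would add this feature-matching term to the distillation loss inside a custom train_step, just like the Distiller class above.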

Method 3: Keras Soft Label Distillation

This simplified method focuses only on the final output probabilities (soft labels) without custom training loops.

I find this method easier to implement for quick prototypes when I don’t want to build a full custom class.

# Pre-calculate the teacher's soft labels (the teacher outputs raw logits,
# so squash them into probabilities first)
teacher_soft_labels = tf.sigmoid(teacher.predict(x_train))

# Train the student directly on the teacher's soft labels; the student
# still outputs a logit, hence from_logits=True
student.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)
student.fit(x_train, teacher_soft_labels, epochs=5)

Knowledge distillation is a powerful tool in your Keras toolkit, especially when you need to balance performance and efficiency.

I have found that by following these steps, you can significantly reduce model size while maintaining a high level of accuracy for real-world applications.
