Text Classification Using Switch Transformer in Keras

I have often struggled with scaling models without making them incredibly slow to train. Standard Transformers are great, but they can become computationally expensive when you want to add more parameters.

Recently, I started using Switch Transformers to solve this problem by using a Mixture-of-Experts (MoE) routing system.

This approach lets a model hold billions of parameters while activating only one expert's worth of them for each token, so the compute cost per token stays roughly constant.

In this tutorial, I will show you exactly how I build and train a Switch Transformer for text classification in Keras.

Set Up Your Keras Environment for Switch Transformers

Before we dive into the architecture, I always make sure my environment is ready with the necessary libraries.

I use TensorFlow and Keras to handle the layer definitions and the training loops for our text data.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

Prepare the Dataset for Keras Text Classification

For this example, I am using a dataset of movie reviews to classify sentiment as positive or negative.

The Keras IMDB dataset already comes tokenized as integer word IDs, so I only need to cap the vocabulary and pad every review to the same length. If you start from raw strings instead, Keras's TextVectorization layer handles tokenization and padding directly inside the model; I sketch that option after the code below.

vocab_size = 20000  # keep only the 20,000 most frequent words
max_len = 200       # pad or truncate every review to 200 tokens

(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=max_len)
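
If your data is raw strings rather than the pre-tokenized IMDB arrays, here is a rough sketch of the TextVectorization alternative I mentioned above (the example sentences are just placeholders):

# Hedged sketch: tokenize raw strings with a TextVectorization layer instead of
# relying on the pre-tokenized IMDB arrays. The example texts are placeholders.
vectorizer = layers.TextVectorization(
    max_tokens=vocab_size, output_sequence_length=max_len
)
raw_texts = ["the movie was wonderful", "a dull and predictable plot"]
vectorizer.adapt(raw_texts)                       # build the vocabulary from your corpus
example_ids = vectorizer(tf.constant(raw_texts))  # integer tensor of shape (2, max_len)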

Create the Token and Position Embedding Layer in Keras

Every Transformer needs to know the position of words, so I built a custom embedding layer to combine word and position info.

This layer ensures that our Switch Transformer understands the order of words in our sentences.

class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
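
Before wiring this layer into a model, I like to confirm the shapes line up by running it on a dummy batch; a quick check along these lines (using the same embedding size as later):

# Quick sanity check on a dummy batch of token IDs.
demo_emb = TokenAndPositionEmbedding(max_len, vocab_size, 32)
dummy_ids = tf.zeros((2, max_len), dtype=tf.int32)
print(demo_emb(dummy_ids).shape)  # expected: (2, 200, 32)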

Implement the Switch Feed-Forward Network in Keras

This is the “Switch” part where the model chooses which expert (a dense network) should process the current token.

I use a softmax router to pick the single best expert for each token. In the real Switch Transformer, each token is dispatched only to that expert, which is what keeps the computation light; in the simplified version below I run every expert and mask out all but the selected one, because it is much easier to read in Keras.

class SwitchFeedForward(layers.Layer):
    def __init__(self, embed_dim, num_experts, expert_dim):
        super().__init__()
        self.num_experts = num_experts
        self.router = layers.Dense(num_experts, activation="softmax")
        self.experts = [
            keras.Sequential([
                layers.Dense(expert_dim, activation="relu"),
                layers.Dense(embed_dim)
            ]) for _ in range(num_experts)
        ]

    def call(self, inputs):
        # Routing probabilities per token: (batch, seq, num_experts)
        router_probs = self.router(inputs)
        # Hard routing: each token keeps only its highest-probability expert
        expert_idx = tf.argmax(router_probs, axis=-1)
        expert_mask = tf.one_hot(expert_idx, self.num_experts)

        # Run every expert and stack the results: (batch, seq, embed_dim, num_experts)
        expert_outputs = tf.stack([expert(inputs) for expert in self.experts], axis=-1)
        expert_mask = tf.expand_dims(expert_mask, axis=-2)

        # Multiply by the mask to zero out every expert except the selected one
        combined_output = tf.reduce_sum(expert_outputs * expert_mask, axis=-1)
        return combined_output
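
Before attaching this routing layer to attention, I verify that it preserves the (batch, sequence, embedding) shape; a minimal check with the hyperparameters used later:

# The MoE layer should map (batch, seq, embed_dim) back to the same shape.
moe = SwitchFeedForward(embed_dim=32, num_experts=4, expert_dim=32)
print(moe(tf.random.normal((2, 200, 32))).shape)  # expected: (2, 200, 32)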

Build the Switch Transformer Block in Keras

Now I combine the Multi-Head Attention layer with my custom Switch Feed-Forward network into a single block.

I include Layer Normalization and residual connections to keep the gradients stable during the training process.

class SwitchTransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, num_experts, ff_dim):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.switch = SwitchFeedForward(embed_dim, num_experts, ff_dim)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        attn_output = self.att(inputs, inputs)
        out1 = self.layernorm1(inputs + attn_output)
        switch_output = self.switch(out1)
        return self.layernorm2(out1 + switch_output)


Method 1: Standard Global Pooling Classification in Keras

In this method, I use a Global Average Pooling layer to flatten the Transformer output before the final classification.

embed_dim = 32
num_heads = 2
num_experts = 4
ff_dim = 32

inputs = layers.Input(shape=(max_len,))
embedding_layer = TokenAndPositionEmbedding(max_len, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = SwitchTransformerBlock(embed_dim, num_heads, num_experts, ff_dim)
x = transformer_block(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu")(x)
outputs = layers.Dense(2, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()


This is my go-to approach for general sentiment analysis because it is fast and very effective.

Method 2: Multi-Layer Switch Transformer Stack in Keras

Sometimes a single block isn’t enough, so I stack multiple Switch Transformer blocks for complex text.

def build_stacked_switch_transformer(num_layers=2):
    inputs = layers.Input(shape=(max_len,))
    x = TokenAndPositionEmbedding(max_len, vocab_size, embed_dim)(inputs)
    
    for _ in range(num_layers):
        x = SwitchTransformerBlock(embed_dim, num_heads, num_experts, ff_dim)(x)
        
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(2, activation="softmax")(x)
    
    stacked_model = keras.Model(inputs=inputs, outputs=outputs)
    stacked_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return stacked_model

stacked_model = build_stacked_switch_transformer()


This allows the Keras model to learn deeper representations of the language patterns in the dataset.
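
If two layers still underfit, the same builder makes it easy to go deeper; for example:

# A deeper variant of the same architecture; more layers means more experts
# and more parameters, so training takes correspondingly longer.
deeper_model = build_stacked_switch_transformer(num_layers=3)
deeper_model.summary()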

Train the Switch Transformer Model in Keras

Training these models takes a bit of patience, and I usually keep the batch size small so everything fits comfortably in GPU memory.

history = model.fit(
    x_train, y_train, 
    batch_size=32, 
    epochs=5, 
    validation_data=(x_val, y_val)
)

I monitor the validation accuracy to make sure the Switch routing logic is actually improving the model’s performance.
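
One way I automate that monitoring is with an EarlyStopping callback; a minimal sketch (a patience of 2 epochs is just my usual starting point):

# Stop training when validation accuracy stops improving and keep the best weights.
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_accuracy", patience=2, restore_best_weights=True
)
# Pass callbacks=[early_stop] to model.fit to enable it.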

Evaluate Keras Model Performance on Test Data

Once the training is done, I run the model against the validation set to see how it performs on unseen reviews.

results = model.evaluate(x_val, y_val, verbose=0)
print(f"Validation Loss: {results[0]}")
print(f"Validation Accuracy: {results[1]}")

This step confirms if the experts are specializing well enough to distinguish between different sentiment types.
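
To peek at the routing itself, I sometimes pull the router probabilities out of the trained block; a rough sketch, assuming the embedding_layer and transformer_block objects from Method 1 are still in scope:

# Hedged sketch: inspect which expert each token is routed to, reusing the
# embedding_layer and transformer_block objects built in Method 1.
embedded = embedding_layer(x_val[:1])
attended = transformer_block.att(embedded, embedded)
normed = transformer_block.layernorm1(embedded + attended)
router_probs = transformer_block.switch.router(normed)
print(tf.argmax(router_probs, axis=-1))  # expert index chosen for every token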

Run Predictions with the Keras Switch Transformer

Finally, I use the model.predict function to classify new sentences and see the Switch Transformer in action.

It is always rewarding to see the model correctly identify the sentiment of a custom review I wrote.

sample_text = ["This movie was absolutely fantastic and I loved the acting!"]
# Note: new raw text needs the same tokenization and padding as the training data.
# For brevity, I predict on an already-processed validation sample here.
prediction = model.predict(x_val[:1])
print(f"Prediction probabilities: {prediction}")
print(f"Predicted class: {np.argmax(prediction, axis=-1)[0]}")  # 0 = negative, 1 = positive

In this tutorial, I showed you how to build a Switch Transformer from scratch using Keras layers.

We covered how to create a custom expert routing system and how to integrate it into a text classification pipeline.

I have found that using Switch Transformers is one of the best ways to scale my Keras projects without exploding my compute budget.
