I have often struggled with scaling models without making them incredibly slow to train. Standard Transformers are great, but they can become computationally expensive when you want to add more parameters.
Recently, I started using Switch Transformers, which address this problem with Mixture-of-Experts (MoE) routing: each token is sent to a single expert feed-forward network.
This approach allows the model to hold billions of parameters while activating only a fraction of them for each token.
In this tutorial, I will show you exactly how I build and train a Switch Transformer for text classification in Keras.
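To put rough numbers on the "fraction of the parameters" idea, here is a small back-of-the-envelope sketch in plain Python. The helper names and the layer sizes (512 and 2048) are illustrative assumptions, not values used later in this tutorial. It counts the weights of one Switch feed-forward layer (a router plus several two-layer experts) and compares the total with what a single token actually touches under top-1 routing.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def switch_ffn_params(embed_dim, expert_dim, num_experts):
    """Return (total, active_per_token) parameter counts for one Switch FFN layer."""
    router = dense_params(embed_dim, num_experts)
    one_expert = dense_params(embed_dim, expert_dim) + dense_params(expert_dim, embed_dim)
    total = router + num_experts * one_expert
    # With top-1 routing, a token uses the router and exactly one expert.
    active = router + one_expert
    return total, active


# Hypothetical sizes, chosen only to show the trend.
for n in (1, 4, 16, 64):
    total, active = switch_ffn_params(embed_dim=512, expert_dim=2048, num_experts=n)
    print(f"experts={n:3d}  total={total:,}  active per token={active:,}")
```

The total grows roughly linearly with the number of experts, while the per-token count barely moves (only the router gets slightly larger).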
Set Up Your Keras Environment for Switch Transformers
Before we dive into the architecture, I always make sure my environment is ready with the necessary libraries.
I use TensorFlow and Keras to handle the layer definitions and the training loops for our text data.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

Prepare the Dataset for Keras Text Classification
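For this example, I am using the IMDB movie review dataset that ships with Keras to classify sentiment as positive or negative.
The reviews come already encoded as integer word indices, so here I only need pad_sequences to pad or truncate each review to a fixed length. For raw text, I prefer the TextVectorization layer because it handles tokenization and padding directly within the Keras model.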
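One detail worth knowing about the pad_sequences call used in this section: by default it pads and truncates at the front of each sequence (padding="pre", truncating="pre"). The function below is a pure-Python stand-in that mimics those defaults for a single sequence. It is not the Keras implementation, and the name pad_pre is made up for this sketch.

```python
def pad_pre(seq, maxlen, value=0):
    """Mimic pad_sequences defaults for one sequence (assumes maxlen >= 1).

    Keeps the last `maxlen` tokens and left-pads with `value`.
    """
    seq = list(seq)[-maxlen:]
    return [value] * (maxlen - len(seq)) + seq


print(pad_pre([5, 8, 13], maxlen=5))              # [0, 0, 5, 8, 13]
print(pad_pre([1, 2, 3, 4, 5, 6, 7], maxlen=5))   # [3, 4, 5, 6, 7]
```

So with max_len = 200, long reviews keep their last 200 tokens and short reviews get zeros at the start.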
vocab_size = 20000
max_len = 200
(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_len)
x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=max_len)

Create the Token and Position Embedding Layer in Keras
Every Transformer needs to know the position of words, so I built a custom embedding layer to combine word and position info.
This layer ensures that our Switch Transformer understands the order of words in our sentences.
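To make the token-plus-position idea concrete, here is a small NumPy sketch of the same lookup-and-add operation. The random tables stand in for learned embedding weights, and the toy sizes are assumptions chosen for readability.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, maxlen, embed_dim = 50, 6, 4            # toy sizes, not the tutorial's

token_table = rng.normal(size=(vocab_size, embed_dim))  # stand-in for token_emb weights
pos_table = rng.normal(size=(maxlen, embed_dim))        # stand-in for pos_emb weights

token_ids = np.array([[3, 7, 7, 0, 0, 0],
                      [9, 1, 4, 7, 2, 0]])          # (batch=2, seq_len=6)

tok = token_table[token_ids]                        # (2, 6, 4): one vector per token
pos = pos_table[np.arange(token_ids.shape[1])]      # (6, 4): one vector per position
out = tok + pos                                     # pos broadcasts over the batch axis

print(out.shape)  # (2, 6, 4)
# Token id 7 appears at positions 1 and 2 of the first row; after adding the
# position vectors, the two occurrences no longer share the same representation.
print(np.allclose(out[0, 1], out[0, 2]))
```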
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

Implement the Switch Feed-Forward Network in Keras
This is the “Switch” part where the model chooses which expert (a dense network) should process the current token.
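I use a softmax router and send each token to its highest-probability expert. Note that, for simplicity, this implementation runs every expert on every token and masks out the unselected outputs, so it demonstrates the routing logic rather than the compute savings. A production implementation would dispatch each token only to its chosen expert.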
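Here is a minimal NumPy sketch of top-1 routing with real dispatch, where each expert only processes the tokens routed to it. To keep it short, each expert is a single weight matrix rather than a two-layer network, and all names and sizes are assumptions for illustration. The output is scaled by the chosen expert's router probability, as in the Switch Transformer formulation.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def top1_dispatch(tokens, router_w, expert_ws):
    """tokens: (n_tokens, d), router_w: (d, n_experts), expert_ws: list of (d, d)."""
    probs = softmax(tokens @ router_w)               # (n_tokens, n_experts)
    choice = probs.argmax(axis=-1)                   # chosen expert per token
    gate = probs[np.arange(len(tokens)), choice]     # probability of the chosen expert
    out = np.zeros_like(tokens)
    for e, w in enumerate(expert_ws):
        idx = np.where(choice == e)[0]
        if idx.size:
            out[idx] = tokens[idx] @ w               # only routed tokens reach expert e
    return out * gate[:, None], choice


rng = np.random.default_rng(0)
d, n_experts, n_tokens = 8, 4, 10                    # toy sizes
tokens = rng.normal(size=(n_tokens, d))
router_w = rng.normal(size=(d, n_experts))
expert_ws = [rng.normal(size=(d, d)) for _ in range(n_experts)]

out, choice = top1_dispatch(tokens, router_w, expert_ws)
print(out.shape, choice)
```

The result matches what a "run all experts, then mask" version would produce, but the work per token is one expert's matrix multiply instead of all of them.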
class SwitchFeedForward(layers.Layer):
    def __init__(self, embed_dim, num_experts, expert_dim):
        super().__init__()
        self.num_experts = num_experts
        self.router = layers.Dense(num_experts, activation="softmax")
        self.experts = [
            keras.Sequential([
                layers.Dense(expert_dim, activation="relu"),
                layers.Dense(embed_dim)
            ]) for _ in range(num_experts)
        ]

    def call(self, inputs):
        router_probs = self.router(inputs)  # (batch, seq_len, num_experts)
        expert_idx = tf.argmax(router_probs, axis=-1)
        expert_mask = tf.one_hot(expert_idx, self.num_experts)
        # Scale the mask by the router probability of the chosen expert.
        # argmax and one_hot pass no gradient, so without this scaling the
        # router weights would never be trained.
        gate = expert_mask * router_probs
        # Every expert runs on every token: (batch, seq_len, embed_dim, num_experts)
        expert_outputs = tf.stack([expert(inputs) for expert in self.experts], axis=-1)
        gate = tf.expand_dims(gate, axis=-2)
        # Multiply expert outputs by the gate to keep only the active expert
        combined_output = tf.reduce_sum(expert_outputs * gate, axis=-1)
        return combined_output

Build the Switch Transformer Block in Keras
Now I combine the Multi-Head Attention layer with my custom Switch Feed-Forward network into a single block.
I include Layer Normalization and residual connections to keep the gradients stable during the training process.
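The "add, then normalize" pattern can be sketched in a few lines of NumPy. This is a simplified stand-in: it leaves out the learned scale and shift that a Keras LayerNormalization layer has, and the linear sublayer is a hypothetical placeholder for attention or the Switch feed-forward network.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """Normalize the last axis to zero mean and unit variance (no learned scale or shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_block(x, sublayer):
    """Post-norm residual step: add the sublayer output to its input, then normalize."""
    return layer_norm(x + sublayer(x))


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))                # (batch, seq_len, embed_dim)
w = rng.normal(size=(8, 8))
out = residual_block(x, lambda h: h @ w)      # hypothetical linear sublayer
print(out.shape)  # (2, 5, 8)
```

The residual path gives gradients a direct route around the sublayer, and the normalization keeps each token vector on a similar scale from block to block.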
class SwitchTransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, num_experts, ff_dim):
        super().__init__()
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.switch = SwitchFeedForward(embed_dim, num_experts, ff_dim)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        attn_output = self.att(inputs, inputs)
        out1 = self.layernorm1(inputs + attn_output)
        switch_output = self.switch(out1)
        return self.layernorm2(out1 + switch_output)

Method 1: Standard Global Pooling Classification in Keras
In this method, I use a Global Average Pooling layer to collapse the sequence dimension of the Transformer output into a single vector per review before the final classification.
embed_dim = 32
num_heads = 2
num_experts = 4
ff_dim = 32
inputs = layers.Input(shape=(max_len,))
embedding_layer = TokenAndPositionEmbedding(max_len, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = SwitchTransformerBlock(embed_dim, num_heads, num_experts, ff_dim)
x = transformer_block(x)
x = layers.GlobalAveragePooling1D()(x)
x = layers.Dropout(0.1)(x)
x = layers.Dense(20, activation="relu")(x)
outputs = layers.Dense(2, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()

This is my go-to approach for general sentiment analysis because it is fast and very effective.
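For readers who want to see what the pooling step does to the tensor shape, here is a tiny NumPy sketch with made-up numbers. It averages over the sequence axis, which is the unmasked behavior of global average pooling. As far as I can tell from the code in this tutorial, no mask is passed to the pooling layer, so padded positions are included in the average.

```python
import numpy as np

# Toy Transformer output: (batch=2, seq_len=4, embed_dim=3)
seq_out = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)

pooled = seq_out.mean(axis=1)   # average over the sequence axis -> (batch, embed_dim)

print(pooled.shape)             # (2, 3)
print(pooled[0].tolist())       # [4.5, 5.5, 6.5]
```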
Method 2: Multi-Layer Switch Transformer Stack in Keras
Sometimes a single block isn’t enough, so I stack multiple Switch Transformer blocks for complex text.
def build_stacked_switch_transformer(num_layers=2):
    inputs = layers.Input(shape=(max_len,))
    x = TokenAndPositionEmbedding(max_len, vocab_size, embed_dim)(inputs)
    for _ in range(num_layers):
        x = SwitchTransformerBlock(embed_dim, num_heads, num_experts, ff_dim)(x)
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(2, activation="softmax")(x)
    stacked_model = keras.Model(inputs=inputs, outputs=outputs)
    stacked_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return stacked_model

stacked_model = build_stacked_switch_transformer()

This allows the Keras model to learn deeper representations of the language patterns in the dataset.
Train the Switch Transformer Model in Keras
Training these models requires a bit of patience, so I usually set a small batch size to fit in GPU memory.
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_val, y_val)
)

I monitor the validation accuracy across epochs to check that the model is still improving and to catch overfitting early.
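One simple way to read the training history is to look up the epoch with the best validation accuracy. The sketch below works on a plain dictionary shaped like history.history, with "val_accuracy" as the key Keras uses when the model is compiled with metrics=["accuracy"]. The numbers are placeholders to show the mechanics, not results from this model, and best_epoch is a helper name made up for this example.

```python
def best_epoch(history_dict, metric="val_accuracy"):
    """Return (1-based epoch number, value) of the highest value of `metric`."""
    values = history_dict[metric]
    best = max(range(len(values)), key=values.__getitem__)
    return best + 1, values[best]


# Placeholder numbers only, NOT measured results.
example_history = {
    "accuracy": [0.70, 0.85, 0.91],
    "val_accuracy": [0.80, 0.86, 0.85],
}
print(best_epoch(example_history))  # (2, 0.86)
```

After training, the same helper could be called as best_epoch(history.history). If validation accuracy peaks early while training accuracy keeps climbing, that is a sign of overfitting.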
Evaluate Keras Model Performance on Test Data
Once training is done, I evaluate the model on x_val, the held-out split that load_data returns as the IMDB test set and that I also used for validation during training.
results = model.evaluate(x_val, y_val, verbose=0)
print(f"Validation Loss: {results[0]}")
print(f"Validation Accuracy: {results[1]}")

This step shows how well the model separates positive from negative reviews. Accuracy alone does not tell you whether the experts have specialized; for that, you would need to inspect which tokens the router sends to each expert.
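One way to look beyond accuracy is to measure how tokens are spread across experts. The sketch below takes an array of router probabilities and reports the fraction of tokens assigned to each expert under top-1 routing. How to extract the router output from a trained model is left out here. The hand-written array is a stand-in, and expert_load is a hypothetical helper name.

```python
import numpy as np


def expert_load(router_probs):
    """router_probs: (..., n_experts). Returns the fraction of tokens routed to each expert."""
    n_experts = router_probs.shape[-1]
    choice = router_probs.reshape(-1, n_experts).argmax(axis=-1)
    counts = np.bincount(choice, minlength=n_experts)
    return counts / counts.sum()


# Stand-in router output: (batch=2, seq_len=2, n_experts=4)
probs = np.array([[[0.7, 0.1, 0.1, 0.1], [0.1, 0.6, 0.2, 0.1]],
                  [[0.5, 0.2, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]]])
print(expert_load(probs).tolist())  # [0.75, 0.25, 0.0, 0.0]
```

A load that is heavily concentrated on one expert, as in this toy input, would suggest the other experts are barely being used.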
Run Predictions with the Keras Switch Transformer
Finally, I use the model.predict function to classify new sentences and see the Switch Transformer in action.
It is always rewarding to see the model correctly identify the sentiment of a custom review I wrote.
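Raw text has to be encoded the same way as the training data before the model can use it. The sketch below shows one way to do that for the Keras IMDB encoding, under these assumptions: word ranks are shifted by 3 (the index_from default of load_data), 0 is padding, 1 marks the start of a review, and 2 stands for out-of-vocabulary words. The simple regex tokenizer is my own approximation, and toy_index is a made-up dictionary. In practice, the real one would come from keras.datasets.imdb.get_word_index().

```python
import re


def encode_review(text, word_index, vocab_size=20000, max_len=200,
                  start_id=1, oov_id=2, index_from=3):
    """Turn a raw string into a left-padded list of IMDB-style integer ids."""
    words = re.findall(r"[a-z0-9']+", text.lower())   # rough tokenizer (assumption)
    ids = [start_id]
    for w in words:
        rank = word_index.get(w)
        idx = rank + index_from if rank is not None else oov_id
        ids.append(idx if idx < vocab_size else oov_id)
    ids = ids[-max_len:]                               # truncate at the front
    return [0] * (max_len - len(ids)) + ids            # pad at the front


toy_index = {"the": 1, "movie": 17, "fantastic": 774}  # hypothetical word ranks
print(encode_review("The movie was fantastic!", toy_index, max_len=8))
# [0, 0, 0, 1, 4, 20, 2, 777]
```

The encoded list can then be wrapped in a NumPy array with a batch dimension, for example np.array([encoded]), and passed to model.predict.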
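sample_text = ["This movie was absolutely fantastic and I loved the acting!"]
# Note: the model expects integer-encoded, padded input, so sample_text cannot
# be passed to it directly. In a real app, you'd encode it with the same word
# index and padding used for training.
# For brevity, an already-encoded validation review is used here instead.
prediction = model.predict(x_val[:1])
# Each row holds two softmax probabilities: [negative, positive]
print(f"Prediction result: {prediction}")

In this tutorial, I showed you how to build a Switch Transformer from scratch using Keras layers.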
We covered how to create a custom expert routing system and how to integrate it into a text classification pipeline.
I have found that using Switch Transformers is one of the best ways to scale my Keras projects without exploding my compute budget.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and other countries.