When Recurrence Meets Transformers in Keras

As a Python Keras developer with over four years of experience, I’ve seen how sequence modeling has evolved dramatically. Recurrent Neural Networks (RNNs) have long been the go-to for time series and sequential data, but transformers have brought a revolution with their attention mechanisms.

Recently, I explored how combining recurrence with transformers can lead to compact, powerful sequence representations. This approach, known as Temporal Latent Bottlenecks, leverages the strengths of both architectures. In this article, I’ll share my firsthand experience and provide you with complete Keras code examples so you can apply this technique to your own projects.

Understand the Basics: Recurrence and Transformers in Keras

Before diving into the combined model, let’s quickly revisit the two components:

  • Recurrent Neural Networks (RNNs) are great for capturing temporal dependencies by processing data sequentially and maintaining a hidden state. In Keras, layers like SimpleRNN, LSTM, and GRU make this easy.
  • Transformers use self-attention mechanisms to model dependencies without recurrence, enabling parallel processing and capturing long-range relationships effectively. Keras provides the MultiHeadAttention layer to implement this.
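
To make the contrast concrete, here is a minimal NumPy sketch of the two core operations: a recurrent update applied one timestep at a time, and single-head scaled dot-product self-attention applied to all timesteps at once. This is not how Keras implements its layers internally; the weight names, sizes, and random initialization are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_in, d_h = 10, 1, 4            # timesteps, input features, hidden size (illustrative)
x = rng.random((T, d_in))

# Recurrence: one hidden state, updated step by step (inherently sequential)
W_x = rng.normal(size=(d_in, d_h))
W_h = rng.normal(size=(d_h, d_h))
h = np.zeros(d_h)
for t in range(T):
    h = np.tanh(x[t] @ W_x + h @ W_h)

# Self-attention: every timestep is compared with every other in one matrix product
W_q, W_k, W_v = (rng.normal(size=(d_in, d_h)) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v
scores = q @ k.T / np.sqrt(d_h)                   # (T, T) pairwise scores
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the keys
attended = weights @ v                            # (T, d_h)

print(h.shape, attended.shape)  # (4,) (10, 4)
```

The loop cannot be parallelized across timesteps because each step needs the previous hidden state, while the attention block has no such dependency.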

Method 1: Build a Simple Recurrent Neural Network in Keras

I always start with a simple RNN to get a feel for how my sequence data behaves. Here’s a basic example that takes a sequence input and outputs a prediction.

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, SimpleRNN, Dense

# Generate dummy sequential data
X = np.random.rand(1000, 10, 1)  # 1000 samples, 10 timesteps, 1 feature
y = np.random.rand(1000, 1)      # 1000 targets

# Build the model
model = Sequential([
    Input(shape=(10, 1)),              # declare the input shape explicitly
    SimpleRNN(32, activation='relu'),  # returns only the final hidden state
    Dense(1)
])

model.compile(optimizer='adam', loss='mse')
model.summary()

# Train the model
model.fit(X, y, epochs=5, batch_size=32)


This model is simple but effective for short sequences. However, it struggles with longer sequences due to vanishing gradients.
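
The vanishing-gradient effect can be illustrated with a scalar toy RNN. The sketch below is an assumption-laden simplification (one hidden unit, tanh activation, a hypothetical helper named gradient_through_time), not the Keras model above. It multiplies the per-step derivatives that backpropagation through time would chain together.

```python
import numpy as np

def gradient_through_time(w, inputs):
    """Return d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t)."""
    h, grad = 0.0, 1.0
    for x_t in inputs:
        h = np.tanh(w * h + x_t)
        grad *= w * (1.0 - h ** 2)   # chain rule: derivative of this step w.r.t. h_{t-1}
    return grad

rng = np.random.default_rng(0)
for steps in (5, 20, 50):
    g = gradient_through_time(w=0.9, inputs=rng.random(steps))
    print(steps, abs(g))             # magnitude shrinks as the sequence gets longer
```

Each factor has magnitude at most |w|, so with |w| < 1 the product decays at least geometrically with sequence length, which is why early timesteps receive almost no learning signal.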

Method 2: Use Transformers for Sequence Modeling in Keras

Transformers handle long-range dependencies better. Here’s how to implement a basic transformer encoder block using Keras:

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LayerNormalization, Dropout, MultiHeadAttention
from tensorflow.keras.models import Model

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Multi-head self-attention with a residual connection
    x = MultiHeadAttention(key_dim=head_size, num_heads=num_heads, dropout=dropout)(inputs, inputs)
    x = Dropout(dropout)(x)
    attn_out = LayerNormalization(epsilon=1e-6)(x + inputs)

    # Feed-forward network, projected back to the input width
    x_ff = Dense(ff_dim, activation="relu")(attn_out)
    x_ff = Dense(inputs.shape[-1])(x_ff)
    x_ff = Dropout(dropout)(x_ff)
    # The residual adds the attention output, not the feed-forward output to itself
    return LayerNormalization(epsilon=1e-6)(attn_out + x_ff)

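# Input layer
input_layer = Input(shape=(10, 1))
# Project the single feature to a wider embedding first (32 is an arbitrary width):
# LayerNormalization over a feature axis of size 1 would output the same value for every input
x = Dense(32)(input_layer)
x = transformer_encoder(x, head_size=64, num_heads=4, ff_dim=128, dropout=0.1)
x = tf.keras.layers.GlobalAveragePooling1D()(x)
output = Dense(1)(x)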

model = Model(inputs=input_layer, outputs=output)
model.compile(optimizer='adam', loss='mse')
model.summary()


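This transformer encoder can handle longer sequences and capture complex patterns, but it may require more data and compute power. Note that the block adds no positional encoding, so with global average pooling the model ignores the order of the timesteps; add positional information if order matters for your data.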
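
One reason for the extra compute is that self-attention scores every pair of timesteps, while a recurrent layer performs one update per timestep. The rough count below is only a back-of-the-envelope sketch; the two helper functions are hypothetical and ignore the feature dimension and the cost of each individual operation.

```python
def attention_score_count(seq_len, num_heads):
    """Number of pairwise attention scores in one self-attention layer."""
    return num_heads * seq_len * seq_len

def recurrent_step_count(seq_len):
    """Number of sequential hidden-state updates in one recurrent layer."""
    return seq_len

for seq_len in (10, 100, 1000):
    print(seq_len, recurrent_step_count(seq_len), attention_score_count(seq_len, num_heads=4))
# 10 10 400
# 100 100 40000
# 1000 1000 4000000
```

The attention scores can be computed in parallel, so the quadratic count mostly shows up as memory and total work rather than as sequential latency.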

Method 3: When Recurrence Meets Transformers — Temporal Latent Bottleneck in Keras

In the combined model, an RNN first compresses the sequence into a latent bottleneck vector, and a transformer block then refines that vector. The RNN supplies the temporal summary, and the transformer block supplies the additional processing on top of it.

Here’s a complete Keras example implementing this idea:

import numpy as np
from tensorflow.keras.layers import Input, SimpleRNN, Dense, LayerNormalization, Dropout, MultiHeadAttention, GlobalAveragePooling1D, Reshape
from tensorflow.keras.models import Model

def temporal_latent_bottleneck(input_shape, rnn_units=32, bottleneck_dim=16, head_size=64, num_heads=4, ff_dim=128, dropout=0.1):
    inputs = Input(shape=input_shape)

    # Step 1: Recurrent layer compresses sequence into bottleneck vector
    rnn_output = SimpleRNN(rnn_units, activation='tanh')(inputs)
    bottleneck = Dense(bottleneck_dim, activation='relu')(rnn_output)

    # Step 2: Expand bottleneck to sequence format for transformer input
    # Reshape is a Keras layer, so it also works where raw tf ops on symbolic tensors are rejected
    expanded = Reshape((1, bottleneck_dim))(bottleneck)  # Shape: (batch_size, 1, bottleneck_dim)

    # Step 3: Transformer encoder refines the latent representation
    x = MultiHeadAttention(key_dim=head_size, num_heads=num_heads, dropout=dropout)(expanded, expanded)
    x = Dropout(dropout)(x)
    attn_out = LayerNormalization(epsilon=1e-6)(x + expanded)

    # Feed-forward network within transformer block
    x_ff = Dense(ff_dim, activation='relu')(attn_out)
    x_ff = Dense(bottleneck_dim)(x_ff)
    x_ff = Dropout(dropout)(x_ff)
    # The residual adds the attention output, not the feed-forward output to itself
    x = LayerNormalization(epsilon=1e-6)(attn_out + x_ff)

    # Step 4: Global pooling and output layer
    x = GlobalAveragePooling1D()(x)
    outputs = Dense(1)(x)

    model = Model(inputs=inputs, outputs=outputs)
    return model

# Generate dummy data
X = np.random.rand(1000, 10, 1)
y = np.random.rand(1000, 1)

model = temporal_latent_bottleneck(input_shape=(10, 1))
model.compile(optimizer='adam', loss='mse')
model.summary()

model.fit(X, y, epochs=5, batch_size=32)

This method compresses the input sequence into a latent vector using an RNN, then applies a transformer block to this bottleneck so that the compressed representation can be refined before the final prediction.
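
One caveat is worth checking before relying on this design. Because the bottleneck is expanded to a sequence of length 1, the softmax in self-attention has only one key to normalize over, so its weight is always 1 and the attention output is simply the value projection of that token. The single-head NumPy sketch below illustrates this under assumed random weights; Keras's MultiHeadAttention adds several heads and an output projection, but each head sees the same single key.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16                                   # same width as bottleneck_dim in the model
token = rng.normal(size=(1, d))          # a "sequence" of length 1
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

q, k, v = token @ W_q, token @ W_k, token @ W_v
scores = q @ k.T / np.sqrt(d)            # shape (1, 1)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights)                           # [[1.]]
print(np.allclose(weights @ v, v))       # True
```

In other words, with a single latent token the transformer block behaves like a stack of learned projections with residual connections, and the temporal modeling is done by the RNN. For attention to mix information, the bottleneck needs more than one token.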

In my experience, pure RNNs can struggle with long sequences, while pure transformers can be resource-heavy and sometimes less effective on small datasets. Combining them aims for the best of both worlds: RNNs for temporal compression and transformers for rich contextualization.

Using Keras, it’s easy to experiment with these architectures. The temporal latent bottleneck model I shared is a powerful tool for sequence modeling tasks, from financial time series forecasting to sensor data analysis.
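
To try any of these models on a real series instead of random data, the input has to be shaped as (samples, timesteps, features), like the dummy arrays used throughout. Below is a minimal sketch of one way to do that for next-step prediction; the make_windows helper and the sine-wave stand-in signal are assumptions for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows(series, timesteps=10):
    """Turn a 1-D series into (X, y) pairs for next-step prediction."""
    windows = sliding_window_view(series, timesteps + 1)   # (n, timesteps + 1)
    X = windows[:, :timesteps, np.newaxis]                  # (n, timesteps, 1)
    y = windows[:, timesteps:]                              # (n, 1)
    return X.copy(), y.copy()                               # copies are writable

series = np.sin(np.linspace(0, 20, 200))   # stand-in for a real signal
X, y = make_windows(series)
print(X.shape, y.shape)  # (190, 10, 1) (190, 1)
```

The resulting X and y can be passed to model.fit in place of the random arrays, ideally after splitting off the most recent windows as a validation set.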

If you want to improve your sequence models, I recommend trying this hybrid approach. It balances complexity and performance and is flexible enough to customize for your specific needs.
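
One possible customization, sketched here in NumPy rather than Keras, is to give the transformer several latent tokens instead of one by summarizing the sequence chunk by chunk. This is a hypothetical variation and not the model shown earlier; the rnn_summary helper, the chunk length, and the weight scales are all assumptions.

```python
import numpy as np

def rnn_summary(chunk, W_x, W_h):
    """Final hidden state of a tanh RNN run over one chunk of shape (steps, features)."""
    h = np.zeros(W_h.shape[0])
    for x_t in chunk:
        h = np.tanh(x_t @ W_x + h @ W_h)
    return h

rng = np.random.default_rng(0)
T, d_in, d_h, chunk_len = 10, 1, 16, 2     # illustrative sizes; T must divide evenly
x = rng.random((T, d_in))
W_x = rng.normal(size=(d_in, d_h))
W_h = rng.normal(size=(d_h, d_h)) * 0.1

chunks = x.reshape(T // chunk_len, chunk_len, d_in)            # (5, 2, 1)
latent_tokens = np.stack([rnn_summary(c, W_x, W_h) for c in chunks])
print(latent_tokens.shape)  # (5, 16)
```

With five latent tokens instead of one, a self-attention layer has several positions to compare, at a fraction of the cost of attending over every raw timestep.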
