Text Classification Using FNet in Python with Keras

I’ve spent the last four years building deep learning models, and one thing I’ve realized is that Transformers can be quite heavy on resources.

Recently, I’ve been experimenting with FNet, a model that replaces the complex self-attention layer with a simple Fourier Transform.

In this guide, I’ll show you exactly how I use FNet for text classification to get great results without the massive computational overhead.

Set Up Your Environment for Python Keras FNet

Before we dive into the code, I always make sure my environment is ready with the latest versions of TensorFlow and Keras.

I prefer using the keras_nlp library because it has built-in support for advanced architectures like FNet, making our job much easier.

# Install the necessary library
!pip install -q --upgrade keras-nlp tensorflow
import tensorflow as tf
import keras_nlp
import keras
from tensorflow.keras import layers

Load the Sentiment Dataset for Keras Text Classification

For this tutorial, I’m using the IMDB movie review dataset, which is a classic benchmark for sentiment analysis.

It contains 25,000 highly polar movie reviews for training, and another 25,000 for testing, which provides a solid foundation for our model.

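# Download and extract the IMDB dataset (the loader below reads it from disk)
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz
# Remove the unlabeled reviews so only the pos/ and neg/ class folders remain
!rm -r aclImdb/train/unsup

# Load the training and validation splits (80/20) from the extracted directory
batch_size = 32
raw_train_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train", batch_size=batch_size, validation_split=0.2, subset="training", seed=1337
)
raw_val_ds = tf.keras.utils.text_dataset_from_directory(
    "aclImdb/train", batch_size=batch_size, validation_split=0.2, subset="validation", seed=1337
)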
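Because text_dataset_from_directory infers labels from the subfolder names in alphanumeric order, I like to check the folder layout before loading anything. Any extra subfolder would be treated as an additional class. The helper below is a small standard-library sketch; summarize_text_dirs is a hypothetical name of my own, and it assumes one .txt file per review.

```python
from pathlib import Path


def summarize_text_dirs(root):
    """Map each class subfolder (sorted by name) to its label index and .txt file count."""
    class_dirs = sorted(p for p in Path(root).iterdir() if p.is_dir())
    return {
        d.name: {"label": index, "files": sum(1 for _ in d.glob("*.txt"))}
        for index, d in enumerate(class_dirs)
    }


# Only runs if the dataset has already been extracted next to the script.
train_root = Path("aclImdb/train")
if train_root.is_dir():
    for name, info in summarize_text_dirs(train_root).items():
        print(f"{name}: label={info['label']}, files={info['files']}")
```

With only neg and pos present, neg should map to label 0 and pos to label 1, which is why a score close to 1 means a positive review later on.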

Text Vectorization Process in Python Keras

I use the TextVectorization layer to turn raw strings into integers, which the neural network can actually process.

I find that limiting the vocabulary to 20,000 words and the sequence length to 250 works best for most classification tasks I handle.

max_features = 20000
sequence_length = 250

# Create the vectorization layer
vectorize_layer = layers.TextVectorization(
    max_tokens=max_features,
    output_mode="int",
    output_sequence_length=sequence_length,
)

# Train the layer on our text data
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)
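To make it clearer what the layer does under the hood, here is a rough pure-Python sketch of the same idea: standardize the text, build a frequency-ranked vocabulary, then map words to integers and pad. It is not the Keras implementation; the function names are my own, and the punctuation handling is only an approximation of the layer's default.

```python
import re
import string
from collections import Counter

PAD, OOV = "", "[UNK]"  # index 0 is reserved for padding, index 1 for unknown words


def standardize(text):
    """Lowercase and strip ASCII punctuation, similar in spirit to the layer's default."""
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text.lower())


def build_vocab(texts, max_tokens):
    """Keep the most frequent words, leaving room for the two reserved tokens."""
    counts = Counter(word for t in texts for word in standardize(t).split())
    most_common = [w for w, _ in counts.most_common(max_tokens - 2)]
    return [PAD, OOV] + most_common


def vectorize(text, vocab, sequence_length):
    """Map words to integer ids, then truncate or zero-pad to a fixed length."""
    index = {w: i for i, w in enumerate(vocab)}
    ids = [index.get(w, 1) for w in standardize(text).split()][:sequence_length]
    return ids + [0] * (sequence_length - len(ids))


vocab = build_vocab(["Great movie, great cast!", "Great plot."], max_tokens=10)
print(vectorize("Great acting!", vocab, sequence_length=4))  # [2, 1, 0, 0]
```

Here "great" is the most frequent word, so it gets id 2, while "acting" was never seen and falls back to the unknown id 1.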

Map the Datasets for Efficient Keras Training

Once the vectorizer is ready, I apply it to my training and validation sets to ensure the data flows smoothly into the FNet model.

I also use .cache() and .prefetch() to speed up the data loading process, which is a trick I use to save time during long training sessions.

def vectorize_text(text, label):
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label

train_ds = raw_train_ds.map(vectorize_text).cache().prefetch(buffer_size=10)
val_ds = raw_val_ds.map(vectorize_text).cache().prefetch(buffer_size=10)

Build the FNet Encoder Block in Keras

The heart of FNet is the FNetEncoder layer, which uses a 2D Discrete Fourier Transform instead of the standard multi-head attention.

I’ve noticed that this makes the model significantly faster during both training and inference without sacrificing much accuracy.

def build_fnet_model():
    inputs = keras.Input(shape=(None,), dtype="int64")
    
    # Embedding and Positional Encoding
    x = layers.Embedding(max_features, 256)(inputs)
    x = keras_nlp.layers.PositionEmbedding(sequence_length=sequence_length)(x)
    
    # FNet Encoder Block
    x = keras_nlp.layers.FNetEncoder(intermediate_dim=512)(x)
    
    # Global pooling and output layer
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)
    
    return keras.Model(inputs, outputs)

fnet_model = build_fnet_model()
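If you are curious what the Fourier mixing step inside the encoder looks like, the NumPy sketch below shows the core idea as I understand it: take a 2D DFT over the sequence and hidden axes, keep the real part, then add the residual and normalize. This is an illustration only, not the keras_nlp code; it leaves out the feedforward sublayer and the learned layer-norm parameters, and the tensor sizes are made up.

```python
import numpy as np


def fourier_mixing(x):
    """Mix tokens with a 2D DFT over the sequence and hidden axes, keeping the real part."""
    return np.fft.fft2(x, axes=(-2, -1)).real


def layer_norm(x, eps=1e-5):
    """Normalize each token vector to zero mean and unit variance (no learned scale/shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def fnet_mixing_sublayer(x):
    """Residual connection around the Fourier mixing step, followed by layer norm."""
    return layer_norm(x + fourier_mixing(x))


rng = np.random.default_rng(0)
tokens = rng.normal(size=(2, 8, 16))  # (batch, sequence_length, hidden_dim), made-up sizes
mixed = fnet_mixing_sublayer(tokens)
print(mixed.shape)  # (2, 8, 16)
```

Notice that the mixing step has no trainable weights at all, which is a big part of why the block is cheaper than multi-head attention.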

Compile the FNet Model with Python Keras Optimizers

I usually go with the Adam optimizer for text classification because it handles sparse gradients effectively.

Since this is a binary classification task (positive vs. negative sentiment), I use binary_crossentropy as my loss function.

fnet_model.compile(
    optimizer="adam", 
    loss="binary_crossentropy", 
    metrics=["accuracy"]
)

fnet_model.summary()
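For reference, binary cross-entropy averages -(y * log(p) + (1 - y) * log(1 - p)) over the batch, where y is the true label and p is the predicted probability. The plain-Python sketch below mirrors that formula; the clipping constant eps is my own choice to avoid log(0), and this is not the Keras implementation.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over a batch, with predictions clipped away from 0 and 1."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)


# A confident correct prediction is penalized far less than a confident wrong one.
print(binary_crossentropy([1, 0], [0.9, 0.1]))
print(binary_crossentropy([1, 0], [0.1, 0.9]))
```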

Train the FNet Classifier in Keras

Now, I start the training process, typically running for about 5 to 10 epochs, depending on the convergence I see in the logs.

I always monitor the validation accuracy to make sure the model isn’t overfitting on the movie review text.

# Training the model
epochs = 5
history = fnet_model.fit(
    train_ds, 
    validation_data=val_ds, 
    epochs=epochs
)
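The history object returned by fit stores the per-epoch metrics in history.history, a plain dictionary. Below is a small sketch of how I might inspect it for the best epoch and for signs of overfitting. Both helper names are hypothetical, and the numbers in example_history are illustrative only, not results from a real run.

```python
def best_epoch(history_dict, monitor="val_accuracy"):
    """Return the 1-based epoch with the highest monitored value, plus that value."""
    values = history_dict[monitor]
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


def looks_overfit(history_dict, patience=2):
    """Flag runs where val_loss has risen for `patience` consecutive epochs at the end."""
    val_loss = history_dict["val_loss"]
    if len(val_loss) <= patience:
        return False
    tail = val_loss[-(patience + 1):]
    return all(later > earlier for earlier, later in zip(tail, tail[1:]))


# Illustrative numbers only, not results from a real run.
example_history = {
    "accuracy": [0.70, 0.85, 0.91, 0.95, 0.98],
    "val_accuracy": [0.78, 0.84, 0.86, 0.85, 0.84],
    "val_loss": [0.48, 0.37, 0.34, 0.38, 0.45],
}
print(best_epoch(example_history))     # (3, 0.86)
print(looks_overfit(example_history))  # True
```

With a real run, you would pass history.history instead of example_history.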

Evaluate Performance of FNet in Python Keras

After training, I evaluate the model on the validation set to see the final accuracy score. For an unbiased estimate, you can run the same evaluation on the held-out reviews in aclImdb/test.

In my experience, FNet gets very close to Transformer-level accuracy while being much more “lightweight” for deployment.

# Check final loss and accuracy
loss, accuracy = fnet_model.evaluate(val_ds)
print(f"Validation Accuracy: {accuracy:.2%}")

Run Predictions with the FNet Keras Model

Finally, I like to test the model with custom strings to see if it correctly identifies the sentiment of a review.

I simply wrap the input in a list, pass it through the vectorizer, and then call the predict method on my model.

# Test with a custom review
sample_review = ["This movie was an absolute masterpiece with great acting!"]
vectorized_sample = vectorize_layer(sample_review)
prediction = fnet_model.predict(vectorized_sample)

print(f"Sentiment Score: {prediction[0][0]} (Closer to 1 is Positive)")

You can see the output in the screenshot below.

[Screenshot: Text Classification Using FNet in Python with Keras]
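The raw sigmoid score is often easier to work with once it is turned into a label. Here is a minimal sketch of that step, assuming class 1 means positive and a default cut-off of 0.5; label_sentiment is a hypothetical helper of mine, not part of Keras.

```python
def label_sentiment(score, threshold=0.5):
    """Turn a sigmoid score in [0, 1] into a label and a confidence for that label."""
    if not 0.0 <= score <= 1.0:
        raise ValueError("score must be a probability between 0 and 1")
    is_positive = score >= threshold
    confidence = score if is_positive else 1.0 - score
    return ("positive" if is_positive else "negative"), confidence


print(label_sentiment(0.9))  # ('positive', 0.9)
```

With the model above, you would call it as label_sentiment(float(prediction[0][0])).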

Alternative Method: Use Tokenizer with FNet

Sometimes I prefer using a dedicated Tokenizer instead of a Vectorization layer for more control over the subword units.

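This method uses keras_nlp.tokenizers. A WordPiece tokenizer can be more robust when dealing with slang or technical jargon, because it breaks unknown words into known subword pieces. Keep in mind that the snippet below reuses the word-level vocabulary from vectorize_layer to keep things short, so it will not actually split words into subwords. For that, you need a vocabulary trained for WordPiece, for example with keras_nlp.tokenizers.compute_word_piece_vocabulary.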

tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vectorize_layer.get_vocabulary(),
    lowercase=True
)

# Example of manual tokenization
encoded_tokens = tokenizer("I loved the cinematography!")
print(encoded_tokens)
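To show what subword splitting means in practice, here is a dependency-free sketch of the greedy longest-match-first rule that WordPiece tokenizers are generally described as using. It is a simplified illustration, not the keras_nlp implementation, and toy_vocab is a made-up vocabulary.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", prefix="##"):
    """Greedy longest-match-first split of one word into WordPiece-style subwords."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = prefix + candidate  # continuation pieces carry the prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece fits, so the whole word is unknown
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"[UNK]", "cinema", "##tography", "love", "##d"}
print(wordpiece_tokenize("cinematography", toy_vocab))  # ['cinema', '##tography']
print(wordpiece_tokenize("loved", toy_vocab))           # ['love', '##d']
print(wordpiece_tokenize("plot", toy_vocab))            # ['[UNK]']
```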

Summary of FNet Advantages in Keras

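One of the main reasons I chose FNet over traditional Transformers is speed: the original FNet paper reports that it trains about 80% faster than BERT on GPUs and 70% faster on TPUs at a sequence length of 512, while reaching 92-97% of BERT’s accuracy on GLUE.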

It’s an excellent choice when you need to deploy a model on edge devices or when you have a limited budget for GPU resources.

In this tutorial, I’ve shown you how to prepare your data, build the FNet architecture, and train it for sentiment classification.

I hope this helps you build faster and more efficient NLP models in your own Python projects.
