Text Classification with Keras Decision Forests and Pretrained Embeddings

In my years of developing with Keras, I have found that deep learning isn't always the best answer for text.

Sometimes, combining the structural power of Decision Forests with the semantic richness of pretrained embeddings yields the best results.

I remember working on a sentiment analysis tool for a New York-based retail brand where speed was just as important as accuracy.

Using TensorFlow Decision Forests (TF-DF) through its Keras API allowed me to achieve high performance without the heavy overhead of recurrent neural networks.

In this tutorial, I will show you exactly how to bridge the gap between static embeddings and random forests in a Keras workflow.

Set Up Your Keras Environment for Decision Forests

First, we need to install the necessary libraries to handle both the data processing and the forest models.

I always ensure tensorflow_decision_forests is installed alongside the core Keras library for seamless integration.

# Installing the required packages (recent TF-DF releases run on Keras 2 via the tf_keras package)
!pip install tensorflow_decision_forests tensorflow tf_keras pandas scikit-learn

Once installed, we import the modules. I prefer using pandas for data handling as it makes cleaning text much faster.

import tensorflow as tf
import tensorflow_decision_forests as tfdf
import pandas as pd
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

Load a USA-Based Dataset for Keras Text Classification

For this example, imagine we are classifying customer feedback from a California-based tech support center.

We will create a synthetic dataset representing common hardware, software, and billing queries.

# Creating a sample dataset related to tech support categories
data = {
    'text': [
        "My MacBook Pro is overheating while running Photoshop in San Jose.",
        "The monthly subscription fee for the cloud service is too high.",
        "I need help resetting my password for the company portal.",
        "The latest software update is crashing on my Dell workstation.",
        "Where can I download the invoice for my recent purchase?",
        "My internet connection is dropping during Zoom calls in Austin."
    ],
    'label': [0, 1, 2, 0, 1, 0] # 0: Hardware/Tech, 1: Billing, 2: Account
}

df = pd.DataFrame(data)

I find that using real-world scenarios helps in understanding how features actually impact the model’s split points.
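With only six rows, it is worth checking how the labels are distributed before trusting any accuracy figure. A minimal stdlib sketch; the LABEL_NAMES dict is simply the comment on the label column written out as a mapping:

```python
from collections import Counter

# Human-readable names for the integer labels used in the sample dataset
LABEL_NAMES = {0: "Hardware/Tech", 1: "Billing", 2: "Account"}
labels = [0, 1, 2, 0, 1, 0]  # same values as data["label"] above

class_counts = Counter(LABEL_NAMES[label] for label in labels)
print(class_counts)  # Counter({'Hardware/Tech': 3, 'Billing': 2, 'Account': 1})
```

The "Account" class has a single example, so any per-class result for it should be read as anecdotal.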

Vectorize Text with Keras TextVectorization

Before we apply embeddings, we must convert our raw strings into integer token IDs that an Embedding layer can look up.

The TextVectorization layer is my go-to tool because it handles tokenization and vocabulary mapping in one step.

# Defining the vectorization layer
vectorizer = layers.TextVectorization(max_tokens=1000, output_sequence_length=20)  # 20 covers the longest sample sentence
vectorizer.adapt(df['text'].values)

# Transforming the text into sequences
integer_data = vectorizer(df['text'].values)
print(f"Tokenized sequences: {integer_data[0]}")

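This layer pads shorter sentences with 0 and truncates longer ones, so every row has exactly output_sequence_length token IDs. That fixed, rectangular shape is what lets a whole batch go through the embedding lookup as a single tensor.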
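To make the tokenization step less of a black box, here is a rough pure-Python approximation of what TextVectorization does in its default int mode: lowercase and strip punctuation, split on whitespace, reserve index 0 for padding and index 1 for out-of-vocabulary tokens, then pad or truncate. The helper names (standardize, build_vocab, vectorize) are hypothetical, and the alphabetical tie-breaking is an assumption of this sketch; the real layer may order equally frequent tokens differently.

```python
import string
from collections import Counter

PAD, OOV = "", "[UNK]"  # reserved for index 0 (padding) and index 1 (out-of-vocabulary)


def standardize(text):
    """Lowercase and drop ASCII punctuation (approximates the default standardization)."""
    return text.lower().translate(str.maketrans("", "", string.punctuation))


def build_vocab(texts, max_tokens=1000):
    counts = Counter(token for text in texts for token in standardize(text).split())
    # Most frequent first; alphabetical tie-breaking is an assumption of this sketch
    ranked = sorted(counts, key=lambda token: (-counts[token], token))
    return [PAD, OOV] + ranked[: max_tokens - 2]


def vectorize(text, vocab, output_sequence_length=10):
    index = {token: i for i, token in enumerate(vocab)}
    ids = [index.get(token, 1) for token in standardize(text).split()]  # unknown -> 1
    ids = ids[:output_sequence_length]                                  # truncate
    return ids + [0] * (output_sequence_length - len(ids))              # pad with 0


vocab = build_vocab(["The cat sat.", "The dog ran!"])
print(vocab)  # ['', '[UNK]', 'the', 'cat', 'dog', 'ran', 'sat']
print(vectorize("The bird sat", vocab, output_sequence_length=5))  # [2, 1, 6, 0, 0]
```

"bird" never appeared during adaptation, so it maps to the OOV index 1, and the two trailing zeros are padding.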

Implement Pretrained Embeddings in Keras

Pretrained embeddings like GloVe or Word2Vec capture the “meaning” of words from corpora containing billions of words.

Instead of training embeddings from scratch, I use these weights to give the model a head start on vocabulary.

# Simulating a pretrained embedding matrix (e.g., GloVe)
voc = vectorizer.get_vocabulary()
num_tokens = len(voc)
embedding_dim = 50

# In a real scenario, you would load weights from a file here
embedding_matrix = np.random.uniform(-0.05, 0.05, (num_tokens, embedding_dim))

embedding_layer = layers.Embedding(
    num_tokens,
    embedding_dim,
    embeddings_initializer=keras.initializers.Constant(embedding_matrix),
    trainable=False,
    mask_zero=True,  # index 0 is padding; lets downstream pooling ignore it
)

By setting trainable=False, I prevent the model from overwriting the valuable semantic relationships already learned.
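The random matrix above only stands in for real vectors. A minimal sketch of the usual loading step, assuming a GloVe-style text file (one token followed by its space-separated floats per line); build_embedding_matrix is a hypothetical helper, and tokens missing from the file keep an all-zero row:

```python
import numpy as np


def build_embedding_matrix(glove_lines, vocabulary, embedding_dim):
    """Align GloVe-style vectors with a vocabulary; tokens not found keep an all-zero row."""
    vectors = {}
    for line in glove_lines:
        word, *values = line.rstrip().split(" ")
        if len(values) == embedding_dim:  # skip malformed lines
            vectors[word] = np.asarray(values, dtype="float32")

    matrix = np.zeros((len(vocabulary), embedding_dim), dtype="float32")
    hits = 0
    for i, token in enumerate(vocabulary):
        if token in vectors:
            matrix[i] = vectors[token]
            hits += 1
    return matrix, hits


# Hypothetical usage with a locally downloaded file (the path is an assumption):
# with open("glove.6B.50d.txt", encoding="utf-8") as f:
#     embedding_matrix, hits = build_embedding_matrix(f, voc, embedding_dim=50)
#     print(f"Found vectors for {hits} of {len(voc)} tokens")
```

The returned matrix can take the place of the random embedding_matrix before the Embedding layer is created, and the hit count gives a quick read on vocabulary coverage.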

Extract Features for Keras Decision Forests

Decision Forests in Keras usually expect fixed-length numerical features rather than sequences.

I use a Global Average Pooling layer to collapse the sequence of embeddings into a single vector per document.

# Creating a feature extraction model
input_node = tf.keras.Input(shape=(None,), dtype="int64")
x = embedding_layer(input_node)
x = layers.GlobalAveragePooling1D()(x)
feature_extractor = tf.keras.Model(inputs=input_node, outputs=x)

# Extracting the numerical features
extracted_features = feature_extractor.predict(integer_data)

This step converts our text into a “dense representation” that acts as the input for our Random Forest.
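One detail matters here: GlobalAveragePooling1D averages over every position, padding included, unless the Embedding layer emits a mask (mask_zero=True in Keras). The NumPy sketch below uses a hypothetical mean_pool helper to show how much padding can shift the average when the padding row is not zero, as with the random matrix above:

```python
import numpy as np


def mean_pool(embedded, token_ids, ignore_padding=True):
    """Average token vectors per document, optionally skipping padding id 0.

    embedded:  (batch, seq_len, dim) array of looked-up embeddings
    token_ids: (batch, seq_len) integer array produced by the vectorizer
    """
    if not ignore_padding:
        return embedded.mean(axis=1)  # what pooling does without a mask
    mask = (token_ids != 0).astype(embedded.dtype)[..., None]  # (batch, seq_len, 1)
    summed = (embedded * mask).sum(axis=1)
    counts = np.maximum(mask.sum(axis=1), 1.0)  # guard against all-padding rows
    return summed / counts


# Toy table: row 0 (padding) is deliberately non-zero, rows 1-2 are two "words"
table = np.array([[9.0, 9.0], [1.0, 0.0], [0.0, 1.0]])
ids = np.array([[1, 2, 0, 0]])  # a two-token document padded to length 4
embedded = table[ids]           # shape (1, 4, 2)

print(mean_pool(embedded, ids))                        # masked: [[0.5 0.5]]
print(mean_pool(embedded, ids, ignore_padding=False))  # unmasked: [[4.75 4.75]]
```

Without the mask, two padding positions dominate a two-word document, which is exactly the kind of noise a decision tree can end up splitting on.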

Train the Keras Random Forest Model

Now we can initialize the RandomForestModel provided by TensorFlow Decision Forests.

Random forests tend to give solid results with default hyperparameters, so they often need much less tuning than a standard neural network.

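# Preparing the dataset for TF-DF
# Give the 50 embedding dimensions descriptive string column names (emb_0 ... emb_49)
feature_names = [f"emb_{i}" for i in range(embedding_dim)]
train_df = pd.DataFrame(extracted_features, columns=feature_names)
train_df["label"] = df["label"]

train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="label")

# Building and training the model
rf_model = tfdf.keras.RandomForestModel()
rf_model.fit(train_ds)
print("Model trained successfully!")

# Keras models have no .score(); compile with a metric and call evaluate instead.
# This is accuracy on the training data, so treat it as an optimistic sanity check.
rf_model.compile(metrics=["accuracy"])
print(rf_model.evaluate(train_ds, return_dict=True))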

You can refer to the screenshot below to see the output.

[Screenshot: Text Classification with Decision Forests and Pretrained Embeddings in Keras]

I love how Keras Decision Forests provide detailed logs of the tree-building process automatically.
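For intuition about what the forest does at prediction time: each tree is trained on a bootstrap sample and casts a vote, and the forest reports the share of votes per class. TF-DF's RandomForestModel votes this way by default (its winner_take_all hyperparameter). The forest_vote function below is a hypothetical stdlib sketch of just this aggregation step, not TF-DF code:

```python
from collections import Counter


def forest_vote(tree_predictions, num_classes):
    """Aggregate one predicted class per tree into a forest prediction."""
    votes = Counter(tree_predictions)
    n_trees = len(tree_predictions)
    # Share of trees voting for each class; this acts as the forest's "probability"
    probabilities = [votes.get(c, 0) / n_trees for c in range(num_classes)]
    # Majority class (ties go to the lower class index)
    winner = max(range(num_classes), key=lambda c: probabilities[c])
    return winner, probabilities


# Five hypothetical trees voting on one support ticket (0: Hardware/Tech, 1: Billing, 2: Account)
print(forest_vote([1, 1, 0, 1, 2], num_classes=3))  # (1, [0.2, 0.6, 0.2])
```

Because each tree sees a different bootstrap sample, individual mistakes tend to cancel out in the vote, which is a large part of why the model needs little tuning.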

Method 1: Use Gradient Boosted Trees in Keras

While Random Forests are great, Gradient Boosted Trees (GBDT) often provide better accuracy for complex text patterns.

I use this method when the dataset has subtle differences between categories that a simple forest might miss.

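# Using Gradient Boosted Trees, which are often more accurate than a random forest
gbt_model = tfdf.keras.GradientBoostedTreesModel()
gbt_model.fit(train_ds)

# Evaluating the GBDT model (again on the training data, so treat it as a sanity check)
gbt_model.compile(metrics=["accuracy"])
print(gbt_model.evaluate(train_ds, return_dict=True))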

You can refer to the screenshot below to see the output.

[Screenshot: Text Classification with Keras Decision Forests and Pretrained Embeddings]

Gradient Boosting builds trees sequentially, correcting the errors of the previous ones, which I find very effective for noisy text.
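The sequential error correction is easiest to see in a toy regression setting. This sketch fits depth-1 trees (stumps) to the residuals of the current ensemble under squared loss. It is a simplification for illustration only: TF-DF's classifier boosts on a log-loss gradient instead, and gradient_boost and fit_stump are hypothetical names.

```python
def fit_stump(xs, residuals):
    """Return the one-split tree (stump) that best fits `residuals` under squared error."""
    best = None
    for threshold in sorted(set(xs))[:-1]:  # the largest value would leave the right side empty
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_value, right_value = sum(left) / len(left), sum(right) / len(right)
        error = sum((r - left_value) ** 2 for r in left) + sum((r - right_value) ** 2 for r in right)
        if best is None or error < best[0]:
            best = (error, threshold, left_value, right_value)
    if best is None:
        raise ValueError("need at least two distinct x values to split")
    _, threshold, left_value, right_value = best
    return lambda x: left_value if x <= threshold else right_value


def gradient_boost(xs, ys, n_trees=20, learning_rate=0.3):
    """Squared-loss gradient boosting: each new stump is fit to the current residuals."""
    base = sum(ys) / len(ys)  # start from the mean prediction
    trees = []

    def predict(x):
        return base + learning_rate * sum(tree(x) for tree in trees)

    for _ in range(n_trees):
        # Residuals are the errors the ensemble still makes (the negative gradient of squared loss)
        residuals = [y - predict(x) for x, y in zip(xs, ys)]
        trees.append(fit_stump(xs, residuals))
    return predict


model = gradient_boost([1, 2, 3, 4], [0, 0, 1, 1])
print(round(model(1), 3), round(model(4), 3))
```

Each stump only removes 30% of the remaining error (the learning rate), so the predictions approach 0 and 1 gradually over the 20 rounds; that shrinkage is what makes boosting less prone to chasing noise in a single step.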

Method 2: Hybrid Keras Functional Model

You can also wrap the feature extractor and the forest into a single pipeline for easier deployment.

With this setup, raw strings go in and class probabilities come out, with no manual feature-extraction step in between.

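# Functional approach: raw strings in, pooled embedding features out
raw_inputs = tf.keras.Input(shape=(), dtype=tf.string)
vectorized = vectorizer(raw_inputs)
embeddings = embedding_layer(vectorized)
pooled = layers.GlobalAveragePooling1D()(embeddings)
preprocessor = tf.keras.Model(inputs=raw_inputs, outputs=pooled)

# TF-DF runs the `preprocessing` model on every batch, then grows the forest on its output
full_model = tfdf.keras.RandomForestModel(preprocessing=preprocessor)

# Train directly on (raw text, label) pairs; no manual feature extraction needed
text_ds = tf.data.Dataset.from_tensor_slices((df["text"].values, df["label"].values)).batch(32)
full_model.fit(text_ds)

# Inference also takes raw strings
print(full_model.predict(tf.constant(["Why was I charged twice this month?"])))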

Using the functional API makes the code cleaner and easier for other developers on my team to read.

Visualize Keras Decision Forest Results

One of the best features of using Forests over standard Deep Learning is interpretability.

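I always plot the trees to see which features drive the classification decisions. Note that each feature here is an averaged embedding dimension, not an individual word, so the splits point to dimensions rather than vocabulary terms.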

# Visualizing the first tree in the forest (top 3 levels) as a standalone HTML file
with open("tree_visualization.html", "w") as f:
    f.write(tfdf.model_plotter.plot_model(rf_model, tree_idx=0, max_depth=3))

Being able to explain why a customer’s email was flagged as “Billing” is a huge advantage in a corporate environment.
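Because every feature the forest sees is an averaged embedding dimension, a split on one column does not name a word directly. One rough heuristic (my own assumption, not a TF-DF feature) is to find the dimensions the model relies on, for example via rf_model.make_inspector().variable_importances(), and then list the vocabulary tokens with the largest values along that dimension of the embedding matrix. The helper top_tokens_for_dimension is hypothetical:

```python
import numpy as np


def top_tokens_for_dimension(embedding_matrix, vocabulary, dim, k=5, skip_reserved=2):
    """List the k tokens with the largest values along one embedding dimension.

    skip_reserved=2 ignores the padding ("") and OOV ("[UNK]") entries at indices 0 and 1.
    """
    column = np.asarray(embedding_matrix)[:, dim]
    ranked = sorted(range(skip_reserved, len(vocabulary)), key=lambda i: column[i], reverse=True)
    return [(vocabulary[i], float(column[i])) for i in ranked[:k]]


# Toy 2-D embeddings for a four-entry vocabulary
toy_matrix = np.array([[0.0, 0.0], [0.5, 0.0], [0.9, 0.1], [0.1, 0.2]])
toy_vocab = ["", "[UNK]", "invoice", "laptop"]
print(top_tokens_for_dimension(toy_matrix, toy_vocab, dim=0, k=2))  # [('invoice', 0.9), ('laptop', 0.1)]
```

With the random stand-in matrix from this tutorial the output carries no meaning; the heuristic only becomes informative once real pretrained vectors are loaded.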

Building a text classifier with Keras Decision Forests and pretrained embeddings is a reliable way to get strong results from relatively little labeled data.

I have used this specific architecture in several production environments with great success.

It combines the best of both worlds: the linguistic knowledge captured by pretrained embeddings such as GloVe or Word2Vec and the tabular efficiency of decision trees.
