Image Classification with BigTransfer (BiT) Using Keras

Image classification is a fundamental task in computer vision. Over the years, transfer learning has become a popular approach to improve model performance without needing massive datasets. One of the most powerful transfer learning methods is BigTransfer (BiT).

I’ve worked extensively with Python Keras, and in this article, I’ll walk you through how to use BiT for image classification with practical, ready-to-use code.

What is BigTransfer (BiT)?

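BigTransfer, or BiT, is a transfer learning recipe from Google Research. ResNet models are pre-trained on large image datasets and then fine-tuned on your target problem. There are two public sets of checkpoints, which differ in pre-training data: BiT-S (ImageNet-1k, about 1.3 million images) and BiT-M (ImageNet-21k, about 14 million images). The largest variant, BiT-L, was trained on the internal JFT-300M dataset and was not released. Because these representations transfer well, the target task needs fewer labeled examples and less training time. In my experience, BiT also keeps hyperparameter tuning simple.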

Use Python Keras for BiT

Python Keras offers an easy-to-use API for deep learning. It integrates well with TensorFlow, making it easy to implement transfer learning models like BiT. If you want to build image classifiers quickly and efficiently, Keras is my go-to tool.

Set Up the Environment

Before we get into the code, make sure you have the necessary libraries installed.

!pip install tensorflow tensorflow_hub tensorflow_datasets

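This installs TensorFlow, TensorFlow Hub (which hosts pre-trained models like BiT), and TensorFlow Datasets. The leading ! runs the command from a Jupyter or Colab cell, so drop it when you install from a terminal. On TensorFlow 2.16 and later, tf.keras points to Keras 3, which rejects hub.KerasLayer inside a Sequential model. In that case, also install tf-keras and set the environment variable TF_USE_LEGACY_KERAS=1 before importing TensorFlow.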
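hub.KerasLayer and Keras 3 do not work together in a Sequential model, so check which versions you actually have before you run the rest of the code. The sketch below uses only the standard library's importlib.metadata. The package list and the rule "Keras 3 installed means you need tf-keras" are my own assumptions, based on TensorFlow 2.16 making Keras 3 the default. Treat the result as a rough hint, not a guarantee.

```python
# A minimal, stdlib-only sketch: report installed versions of the packages this
# tutorial uses. Assumption: Keras >= 3 means hub.KerasLayer needs tf-keras.
from importlib import metadata

PACKAGES = ["tensorflow", "tensorflow-hub", "tensorflow-datasets", "keras", "tf-keras"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def major_version(version):
    """Parse the leading integer of a version string such as '3.3.3'."""
    return int(version.split(".")[0])


def needs_legacy_keras(versions):
    """True when Keras 3 is installed (heuristic, see assumption above)."""
    keras_version = versions.get("keras")
    return keras_version is not None and major_version(keras_version) >= 3


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name:22s} {version or 'not installed'}")
    if needs_legacy_keras(found):
        print("Keras 3 detected: install tf-keras and set TF_USE_LEGACY_KERAS=1 "
              "before importing TensorFlow.")
```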

Load BigTransfer (BiT) Pretrained Model in Keras

TensorFlow Hub provides BiT models ready for transfer learning. Here’s how to load a BiT model and prepare it for your classification task.

import tensorflow as tf
import tensorflow_hub as hub

def load_bit_model(num_classes):
    # Load BiT model from TensorFlow Hub
    bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
    base_model = hub.KerasLayer(bit_model_url, trainable=False, input_shape=(224, 224, 3))

    # Build the classification head
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])

    return model

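This function loads BiT-M R50x1 (a ResNet-50 pre-trained on ImageNet-21k) from TensorFlow Hub as a frozen feature extractor (trainable=False). It then adds a custom classification head that uses dropout for regularization. Because the base is frozen, training updates only the head's weights.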
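The BiT base is frozen, so only the head's weights train. A quick count shows how small that trainable part is. It assumes the m-r50x1 feature extractor outputs a 2048-dimensional vector, the usual width of ResNet-50's pooled features. The helper names are hypothetical. Calling model.summary() on your own model gives the authoritative number.

```python
# Back-of-the-envelope count of the trainable weights in the head above.
# Assumption: the frozen BiT-M R50x1 feature extractor emits a 2048-dim vector.
FEATURE_DIM = 2048


def dense_params(inputs, units):
    """A Dense layer has one weight per input/unit pair plus one bias per unit."""
    return inputs * units + units


def head_params(num_classes, feature_dim=FEATURE_DIM, hidden_units=512):
    hidden = dense_params(feature_dim, hidden_units)  # Dense(512, relu)
    dropout = 0                                       # Dropout has no weights
    output = dense_params(hidden_units, num_classes)  # Dense(num_classes, softmax)
    return hidden + dropout + output


print(head_params(num_classes=5))  # tf_flowers has 5 classes -> 1051653
```

The 512-unit hidden layer holds almost all of these weights. The output layer adds only a few thousand, even with more classes.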

Prepare the Dataset for Image Classification

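For demonstration, I use the TensorFlow Flowers dataset (tf_flowers). It contains 3,670 flower photos in five classes: daisy, dandelion, roses, sunflowers, and tulips. It ships with only a train split, so the code below uses the first 80% for training and the remaining 20% for validation. You can replace it with your own dataset.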

import tensorflow_datasets as tfds

def prepare_dataset(batch_size=32):
    (train_ds, val_ds), ds_info = tfds.load(
        'tf_flowers',
        split=['train[:80%]', 'train[80%:]'],
        as_supervised=True,
        with_info=True
    )

    def preprocess(image, label):
        image = tf.image.resize(image, (224, 224))
        image = image / 255.0  # Normalize to [0,1]
        return image, label

    train_ds = train_ds.map(preprocess).shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    num_classes = ds_info.features['label'].num_classes
    return train_ds, val_ds, num_classes

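This function loads and splits the dataset. It resizes every image to 224x224 and scales pixel values to [0, 1], the input range the BiT models on TF Hub expect. It then shuffles the training set and batches and prefetches both splits.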
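Here is the arithmetic for what the 80/20 split means in practice for tf_flowers. It assumes the train split holds 3,670 images and uses the default batch_size=32. The integer math is exact here because 80% of 3,670 is a whole number. For other dataset sizes, TFDS's own rounding of percent slices may differ by an image. The function names are hypothetical helpers.

```python
import math

# Assumption: tf_flowers' single "train" split holds 3,670 images.
TOTAL_IMAGES = 3670


def split_sizes(total, train_percent=80):
    """Sizes of the 'train[:80%]' and 'train[80%:]' slices (exact for 3,670)."""
    n_train = total * train_percent // 100
    return n_train, total - n_train


def steps_per_epoch(num_examples, batch_size=32):
    """batch() keeps the final partial batch by default, so round up."""
    return math.ceil(num_examples / batch_size)


n_train, n_val = split_sizes(TOTAL_IMAGES)
print(n_train, n_val)                                    # 2936 734
print(steps_per_epoch(n_train), steps_per_epoch(n_val))  # 92 23
```

These step counts should match the progress bar Keras shows for each epoch. If you pass drop_remainder=True to batch(), the counts round down instead.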

Compile and Train the BiT Model with Keras

Now, let’s compile the model with an optimizer and loss function suitable for multi-class classification, and train it.

def compile_and_train(model, train_ds, val_ds, epochs=5):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs
    )
    return history

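This function compiles the model with the Adam optimizer (default learning rate 0.001) and sparse categorical cross-entropy. That loss works directly with integer class labels like the ones tf_flowers provides. The function then trains the model and returns the History object.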
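Sparse categorical cross-entropy is easy to compute by hand. For each image, it is the negative log of the probability the model assigns to the true class index, and Keras reports the average over each batch. The pure-Python sketch below needs no TensorFlow, and its helper names are hypothetical. It also shows a useful sanity check: a 5-class model that guesses uniformly has a loss of about log(5) ≈ 1.609, so an initial loss far above that is worth investigating.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(probs, label):
    """Loss for one example: minus the log-probability of the true class index."""
    return -math.log(probs[label])


# A 5-class head that predicts uniform probabilities has a loss of log(5).
uniform = softmax([0.0] * 5)
print(round(sparse_categorical_crossentropy(uniform, label=2), 3))  # 1.609

# A confident, correct prediction drives the loss toward 0.
confident = softmax([0.0, 0.0, 5.0, 0.0, 0.0])
print(round(sparse_categorical_crossentropy(confident, label=2), 3))
```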

Evaluate Model Performance

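After training, check how well the model performs on the validation data. This split is also used for monitoring during training, so treat the result as a validation score rather than an unbiased test score.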

def evaluate_model(model, val_ds):
    loss, accuracy = model.evaluate(val_ds)
    print(f"Validation Loss: {loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")

This simple method prints out the loss and accuracy on the validation set.
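With integer labels and the 'accuracy' metric string, Keras uses sparse categorical accuracy: it takes the argmax of each softmax output and compares it with the label. Here is a pure-Python sketch of that computation with made-up predictions. sparse_accuracy is a hypothetical helper. To apply it to real outputs, you could pair model.predict(val_ds) with the labels collected from val_ds. Both come out in the same order because val_ds is not shuffled.

```python
def sparse_accuracy(predictions, labels):
    """Fraction of examples whose highest-probability class equals the label."""
    correct = 0
    for probs, label in zip(predictions, labels):
        predicted = max(range(len(probs)), key=probs.__getitem__)  # argmax
        correct += predicted == label
    return correct / len(labels)


preds = [
    [0.7, 0.1, 0.1, 0.05, 0.05],  # predicts class 0
    [0.1, 0.2, 0.6, 0.05, 0.05],  # predicts class 2
    [0.3, 0.4, 0.1, 0.1, 0.1],    # predicts class 1
]
print(sparse_accuracy(preds, labels=[0, 2, 3]))  # 2 of 3 correct
```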

Put It All Together: Full Workflow

Here’s the complete script combining all the parts:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

def load_bit_model(num_classes):
    bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
    base_model = hub.KerasLayer(bit_model_url, trainable=False, input_shape=(224, 224, 3))
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    return model

def prepare_dataset(batch_size=32):
    (train_ds, val_ds), ds_info = tfds.load(
        'tf_flowers',
        split=['train[:80%]', 'train[80%:]'],
        as_supervised=True,
        with_info=True
    )
    def preprocess(image, label):
        image = tf.image.resize(image, (224, 224))
        image = image / 255.0
        return image, label
    train_ds = train_ds.map(preprocess).shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    num_classes = ds_info.features['label'].num_classes
    return train_ds, val_ds, num_classes

def compile_and_train(model, train_ds, val_ds, epochs=5):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    return history

def evaluate_model(model, val_ds):
    loss, accuracy = model.evaluate(val_ds)
    print(f"Validation Loss: {loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")

if __name__ == "__main__":
    train_ds, val_ds, num_classes = prepare_dataset()
    model = load_bit_model(num_classes)
    compile_and_train(model, train_ds, val_ds)
    evaluate_model(model, val_ds)

I executed the above example code and added the screenshot below.

[Screenshot: Image Classification with BigTransfer (BiT) Keras]

Run this script to train a BiT-based image classifier on the flowers dataset. You can customize the dataset and model parameters for your own use case.

Using BiT with Python Keras leverages powerful pre-trained features, enabling you to train models with less data and time. The modular design of Keras makes it easy to swap out components, such as the classification head or optimizer. In my projects, this approach consistently produces strong results with minimal tuning.

Image classification with BigTransfer in Python Keras is both simple and effective. By following this guide, you can build your own high-performing image classifiers. If you have questions or want to share your experience, feel free to comment below.

If you want to explore more about transfer learning and advanced image classification techniques, keep following PythonGuides.com for expert tutorials.
