Image classification is a fundamental task in computer vision. Over the years, transfer learning has become a popular approach to improve model performance without needing massive datasets. One of the most powerful transfer learning methods is BigTransfer (BiT).
I’ve worked extensively with Python Keras, and in this article, I’ll walk you through how to use BiT for image classification with practical, ready-to-use code.
What is BigTransfer (BiT)?
BigTransfer, or BiT, is a transfer learning method that reuses large-scale pre-trained models for image classification tasks. It takes representations learned on huge datasets (the BiT-M checkpoints used in this article were pre-trained on ImageNet-21k) and transfers them to your target problem, so training converges faster and usually reaches higher accuracy than training from scratch. In my experience, it simplifies hyperparameter tuning and boosts sample efficiency.
Use Python Keras for BiT
Keras offers a high-level API for deep learning and integrates well with TensorFlow, so transfer learning models like BiT take only a few lines to set up. If you want to build image classifiers quickly and efficiently, Keras is my go-to tool.
Set Up the Environment
Before we get into the code, make sure you have the necessary libraries installed.
!pip install tensorflow tensorflow_hub tensorflow_datasets
This installs TensorFlow, TensorFlow Hub (which hosts pre-trained models like BiT), and TensorFlow Datasets. The leading ! runs the command from a notebook cell; drop it when installing from a terminal.
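Before moving on, it can help to confirm that the three packages are visible to the interpreter you will run the tutorial with. The snippet below is a minimal sketch using only the standard library; the helper name find_missing is hypothetical and not part of any of these packages.

```python
import importlib.util

# Import names of the three packages installed above.
REQUIRED_MODULES = ["tensorflow", "tensorflow_hub", "tensorflow_datasets"]


def find_missing(modules):
    """Return the module names that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = find_missing(REQUIRED_MODULES)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

If anything is reported as missing, rerun the install command in the same environment (or notebook kernel) that executes the code.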
Load BigTransfer (BiT) Pretrained Model in Keras
TensorFlow Hub provides BiT models ready for transfer learning. Here’s how to load a BiT model and prepare it for your classification task.
import tensorflow as tf
import tensorflow_hub as hub

def load_bit_model(num_classes):
    # Load BiT model from TensorFlow Hub
    bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
    base_model = hub.KerasLayer(bit_model_url, trainable=False, input_shape=(224, 224, 3))

    # Build the classification head
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    return model

This function loads the BiT ResNet-50 model and adds a custom classification head with dropout for regularization.
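Because the backbone is frozen (trainable=False), only the classification head is trained. As a rough illustration of how small that head is, the sketch below counts its parameters in plain Python. It assumes the backbone emits a 2048-dimensional feature vector, which is the usual width for a ResNet-50x1; the helper names dense_params and head_params are hypothetical.

```python
def dense_params(inputs, units):
    """Weights plus biases of one fully connected layer."""
    return inputs * units + units


def head_params(num_classes, feature_dim=2048, hidden_units=512):
    """Trainable parameters in the Dense(512) -> Dropout -> Dense(num_classes) head.

    feature_dim=2048 is an assumption about the backbone's output width.
    """
    hidden = dense_params(feature_dim, hidden_units)
    output = dense_params(hidden_units, num_classes)
    return hidden + output  # Dropout adds no parameters


# Five classes, as in the flowers dataset used later in this article.
print(head_params(num_classes=5))  # 1051653
```

Roughly one million trainable parameters is a small fraction of a full ResNet-50, which is one reason training only the head is fast.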
Prepare the Dataset for Image Classification
For demonstration, I use the TensorFlow Flowers dataset, which contains images of flowers classified into multiple categories. You can replace it with your dataset.
import tensorflow_datasets as tfds

def prepare_dataset(batch_size=32):
    (train_ds, val_ds), ds_info = tfds.load(
        'tf_flowers',
        split=['train[:80%]', 'train[80%:]'],
        as_supervised=True,
        with_info=True
    )

    def preprocess(image, label):
        image = tf.image.resize(image, (224, 224))
        image = image / 255.0  # Normalize to [0,1]
        return image, label

    train_ds = train_ds.map(preprocess).shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    num_classes = ds_info.features['label'].num_classes
    return train_ds, val_ds, num_classes

This method loads and preprocesses images by resizing and normalizing, then batches them for training.
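To get a feel for what the 80/20 split and the batch size mean in practice, the sketch below works out the split sizes and the number of batches per epoch in plain Python. The total of 3670 images is the figure listed for tf_flowers in the TFDS catalog and is treated here as an assumption; split_and_batches is a hypothetical helper, and its floor division may differ from TFDS rounding when the percentage does not divide evenly.

```python
import math


def split_and_batches(total_examples, train_percent=80, batch_size=32):
    """Estimate split sizes and batches per epoch for a percentage split."""
    train_size = total_examples * train_percent // 100
    val_size = total_examples - train_size
    return {
        "train_examples": train_size,
        "val_examples": val_size,
        # The last batch may be smaller, so round up.
        "train_batches": math.ceil(train_size / batch_size),
        "val_batches": math.ceil(val_size / batch_size),
    }


# 3670 is an assumed dataset size taken from the TFDS catalog entry.
print(split_and_batches(3670))
```

With these numbers, each epoch runs 92 training batches and 23 validation batches at the default batch size of 32.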
Compile and Train the BiT Model with Keras
Now, let’s compile the model with an optimizer and loss function suitable for multi-class classification, and train it.
def compile_and_train(model, train_ds, val_ds, epochs=5):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs
    )
    return history

This function compiles the model using the Adam optimizer and trains it, returning the training history.
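The 'sparse' variant of categorical cross-entropy is used because the dataset yields integer class labels rather than one-hot vectors. The sketch below shows the underlying arithmetic in plain Python. It is a simplified illustration, not the Keras implementation (which also clips probabilities for numerical stability), and the sample values are made up.

```python
import math


def sparse_categorical_crossentropy(labels, probabilities):
    """Mean negative log-probability assigned to the true class of each example."""
    losses = [-math.log(probs[label]) for label, probs in zip(labels, probabilities)]
    return sum(losses) / len(losses)


# Two made-up examples with three classes each; labels are plain integers.
labels = [1, 0]
probabilities = [
    [0.1, 0.7, 0.2],    # true class 1 gets probability 0.7
    [0.5, 0.25, 0.25],  # true class 0 gets probability 0.5
]
print(round(sparse_categorical_crossentropy(labels, probabilities), 4))  # 0.5249
```

The loss shrinks toward zero as the softmax output puts more probability on the correct class, which is what the optimizer pushes the head toward during training.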
Evaluate Model Performance
After training, it’s important to evaluate how well the model performs on validation data.
def evaluate_model(model, val_ds):
    loss, accuracy = model.evaluate(val_ds)
    print(f"Validation Loss: {loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")

This simple method prints out the loss and accuracy on the validation set.
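For reference, the accuracy metric reported here is simply the fraction of images whose highest-probability class matches the label. The plain-Python sketch below illustrates that calculation on made-up softmax outputs; the accuracy helper is hypothetical and is not how Keras computes the metric internally.

```python
def accuracy(labels, probabilities):
    """Fraction of examples whose most probable class equals the true label."""
    predictions = [max(range(len(probs)), key=probs.__getitem__) for probs in probabilities]
    correct = sum(int(pred == label) for pred, label in zip(predictions, labels))
    return correct / len(labels)


# Made-up softmax outputs for three images and three classes.
labels = [1, 0, 2]
probabilities = [
    [0.1, 0.7, 0.2],    # predicted class 1, correct
    [0.5, 0.25, 0.25],  # predicted class 0, correct
    [0.6, 0.3, 0.1],    # predicted class 0, wrong (true class is 2)
]
print(f"Accuracy: {accuracy(labels, probabilities):.4f}")  # Accuracy: 0.6667
```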
Put It All Together: Full Workflow
Here’s the complete script combining all the parts:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

def load_bit_model(num_classes):
    bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
    base_model = hub.KerasLayer(bit_model_url, trainable=False, input_shape=(224, 224, 3))
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    return model

def prepare_dataset(batch_size=32):
    (train_ds, val_ds), ds_info = tfds.load(
        'tf_flowers',
        split=['train[:80%]', 'train[80%:]'],
        as_supervised=True,
        with_info=True
    )

    def preprocess(image, label):
        image = tf.image.resize(image, (224, 224))
        image = image / 255.0
        return image, label

    train_ds = train_ds.map(preprocess).shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.map(preprocess).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    num_classes = ds_info.features['label'].num_classes
    return train_ds, val_ds, num_classes

def compile_and_train(model, train_ds, val_ds, epochs=5):
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    return history

def evaluate_model(model, val_ds):
    loss, accuracy = model.evaluate(val_ds)
    print(f"Validation Loss: {loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")

if __name__ == "__main__":
    train_ds, val_ds, num_classes = prepare_dataset()
    model = load_bit_model(num_classes)
    compile_and_train(model, train_ds, val_ds)
    evaluate_model(model, val_ds)

I executed the above example code and added the screenshot below.

Run this script to train a BiT-based image classifier on the flowers dataset. You can customize the dataset and model parameters for your own use case.
Using BiT with Python Keras leverages powerful pre-trained features, enabling you to train models with less data and time. The modular design of Keras makes it easy to swap out components, such as the classification head or optimizer. In my projects, this approach consistently produces strong results with minimal tuning.
Image classification with BigTransfer in Python Keras is straightforward and effective. By following this guide, you can build your own high-performing image classifiers. If you have questions or want to share your experience, feel free to comment below.

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and Scikit-Learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere.