Segmenting 3D point clouds is a key task in computer vision and robotics. From my experience as a Python Keras developer, PointNet provides a simple yet powerful architecture to handle unordered point sets directly.
In this tutorial, I will guide you through building a PointNet model for point cloud segmentation using Keras. You’ll get full code examples for every step, making it easy to follow and apply.
What is PointNet and Why Use Keras?
PointNet is a deep learning architecture designed specifically for point clouds. It respects the unordered nature of point sets and learns features directly from raw 3D coordinates.
Using Python Keras, we can easily implement PointNet’s layers and train it for segmentation tasks on 3D data.
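To make the "unordered" point concrete, here is a minimal NumPy sketch of the idea PointNet relies on: a shared per-point layer followed by max pooling gives the same result no matter how the points are ordered. The function name `global_max_feature` and the random weights are my own illustrative assumptions, not part of any library.

```python
import numpy as np

def global_max_feature(points, weights):
    """Apply a shared per-point linear layer + ReLU, then max-pool over points."""
    per_point = np.maximum(points @ weights, 0.0)  # shape: (num_points, num_features)
    return per_point.max(axis=0)                   # shape: (num_features,)

rng = np.random.default_rng(0)
points = rng.random((128, 3))            # one toy cloud of 128 xyz points
weights = rng.standard_normal((3, 16))   # hypothetical shared weights

# Same points, different order
shuffled = points[rng.permutation(len(points))]

feat_original = global_max_feature(points, weights)
feat_shuffled = global_max_feature(shuffled, weights)

# Reordering the points leaves the pooled feature unchanged
print(np.allclose(feat_original, feat_shuffled))
```

The `Conv1D` layers with kernel size 1 and the `GlobalMaxPooling1D` layer used later in this tutorial play the same two roles inside the Keras model.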
Method 1: Build the PointNet Model Architecture in Keras
This method shows how to implement the core PointNet architecture for segmentation.
Step 1: Import Required Libraries
Import the essential libraries needed to build and train the PointNet model.
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
Step 2: Define the Input Transformation Network (T-Net)
Create the T-Net module that learns spatial alignment for input point clouds.
def tnet(inputs, k=3):
    # Shared per-point MLP (64, 128, 1024) built from kernel-size-1 convolutions
    x = layers.Conv1D(64, 1, activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(128, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(1024, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    # Max-pool over the points to get an order-independent descriptor
    x = layers.GlobalMaxPooling1D()(x)
    x = layers.Dense(512, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    # Zero weights plus an identity bias make the initial transform the identity matrix
    identity_bias = tf.keras.initializers.Constant(np.eye(k).flatten())
    x = layers.Dense(k * k, kernel_initializer='zeros', bias_initializer=identity_bias)(x)
    x = layers.Reshape((k, k))(x)
    return x
Step 3: Build the PointNet Segmentation Model
Assemble the full PointNet segmentation model combining transforms and MLPs.
def create_pointnet_segmentation(num_points=2048, num_classes=4):
    inputs = layers.Input(shape=(num_points, 3))
    # Input transform: align the raw xyz coordinates with a learned 3x3 matrix
    tnet1 = tnet(inputs, k=3)
    x = layers.Dot(axes=(2, 1))([inputs, tnet1])
    # MLP (64, 64)
    x = layers.Conv1D(64, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(64, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    # Feature transform: align the 64-dim per-point features
    tnet2 = tnet(x, k=64)
    x = layers.Dot(axes=(2, 1))([x, tnet2])
    # Keep the per-point (local) features for the concatenation further down
    local_feat = x
    # MLP (64, 128, 1024)
    x = layers.Conv1D(64, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(128, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(1024, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    # Global feature: max-pool over points, then repeat it for every point
    global_feat = layers.GlobalMaxPooling1D()(x)
    global_feat = layers.RepeatVector(num_points)(global_feat)
    # Concatenate local (64-dim) and global (1024-dim) features for each point
    x = layers.Concatenate()([local_feat, global_feat])
    # MLP for segmentation
    x = layers.Conv1D(512, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(256, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(128, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    # Per-point class probabilities
    x = layers.Conv1D(num_classes, 1, activation='softmax')(x)
    model = models.Model(inputs=inputs, outputs=x)
    return model

model = create_pointnet_segmentation()
model.summary()
You can see the output in the screenshot below.

This setup builds a complete PointNet architecture ready for segmentation tasks.
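If the `Dot(axes=(2, 1))` calls in the model look opaque, it may help to see the same operation outside Keras. As far as I can tell, they amount to a batched matrix multiplication: each cloud of shape (N, k) is multiplied by its own (k, k) transform. The sketch below reproduces that with NumPy; the helper name `apply_transform` and the toy shapes are my own assumptions for illustration.

```python
import numpy as np

def apply_transform(points, transform):
    """Multiply each cloud (N, k) by its own (k, k) matrix, batch-wise."""
    # Contract the last axis of points with the first matrix axis of transform
    return np.einsum('bnk,bkj->bnj', points, transform)

rng = np.random.default_rng(0)
points = rng.random((2, 5, 3))  # batch of 2 clouds, 5 points each

# An identity transform (the intended starting point of the T-Net) changes nothing
identity = np.tile(np.eye(3), (2, 1, 1))
unchanged = apply_transform(points, identity)

# A 90-degree rotation about the z-axis maps (x, y, z) to (-y, x, z)
rot_z = np.array([[0.0, 1.0, 0.0],
                  [-1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0]])
rotated = apply_transform(points, np.tile(rot_z, (2, 1, 1)))

print(unchanged.shape, rotated.shape)
```

Because the points are treated as row vectors, the matrix sits on the right-hand side of the product, which is why the contraction is over axis 2 of the points and axis 1 of the transform.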
Method 2: Train the PointNet Model on Dummy Data
This method shows how to prepare dummy point cloud data and train the model.
Step 1: Generate Dummy Point Cloud Data
Create synthetic point cloud data and labels for training PointNet.
num_points = 2048
num_classes = 4
num_samples = 100
# Random 3D points
X_train = np.random.rand(num_samples, num_points, 3).astype(np.float32)
# Random labels for each point (segmentation)
y_train = np.random.randint(0, num_classes, size=(num_samples, num_points))
y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
Step 2: Compile and Train the Model
Compile the segmentation model and train it using the dummy dataset.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(X_train, y_train, epochs=10, batch_size=8)
You can see the output in the screenshot below.

This method lets you quickly verify training behavior using artificial data.
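When reading the training log on random labels, it helps to know what "chance level" looks like. Since the dummy labels carry no signal, I would expect the loss to hover around ln(num_classes) and the accuracy around 1/num_classes. The sketch below computes both reference values with NumPy; it is a sanity-check aid under that assumption, not part of the training code.

```python
import numpy as np

num_classes = 4

# Cross-entropy of a uniform prediction: -log(1 / num_classes) = log(num_classes)
chance_loss = float(np.log(num_classes))
# Accuracy of guessing among equally likely classes
chance_accuracy = 1.0 / num_classes

# Cross-check empirically against random one-hot labels like the dummy data
rng = np.random.default_rng(0)
labels = rng.integers(0, num_classes, size=(100, 2048))
one_hot = np.eye(num_classes)[labels]
uniform_probs = np.full(one_hot.shape, 1.0 / num_classes)
empirical_loss = float(-(one_hot * np.log(uniform_probs)).sum(axis=-1).mean())

print(f"chance loss: {chance_loss:.3f}, chance accuracy: {chance_accuracy:.2f}")
```

A training loss that drops well below this value on random labels most likely means the model is memorizing the 100 samples rather than learning anything transferable.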
Tips for Effective Point Cloud Segmentation in Keras
- Normalize each point cloud by centering it at the origin (zero mean) and scaling it to fit inside the unit sphere for better training.
- Use data augmentation like random jittering and rotation on points.
- Experiment with batch size and learning rate for stable convergence.
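The first two tips can be sketched in a few lines of NumPy. The function names, the jitter settings (`sigma=0.01`, `clip=0.05`), and the choice to rotate only about the z-axis are my own assumptions for illustration; tune them for your data.

```python
import numpy as np

def normalize_point_cloud(points):
    """Center a (N, 3) cloud at the origin and scale it into the unit sphere."""
    centered = points - points.mean(axis=0)
    max_radius = np.linalg.norm(centered, axis=1).max()
    return centered / max_radius

def jitter(points, rng, sigma=0.01, clip=0.05):
    """Add small, clipped Gaussian noise to every coordinate."""
    noise = np.clip(sigma * rng.standard_normal(points.shape), -clip, clip)
    return points + noise

def rotate_about_z(points, rng):
    """Rotate the cloud by a random angle about the z-axis."""
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rotation = np.array([[c, -s, 0.0],
                         [s, c, 0.0],
                         [0.0, 0.0, 1.0]])
    return points @ rotation.T

rng = np.random.default_rng(0)
cloud = rng.random((2048, 3))  # toy cloud standing in for real data

normalized = normalize_point_cloud(cloud)
augmented = jitter(rotate_about_z(normalized, rng), rng)
print(normalized.shape, augmented.shape)
```

None of these functions reorder the points, so per-point segmentation labels stay aligned with the augmented cloud.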
PointNet is a powerful architecture for point cloud segmentation, and implementing it with Python Keras is easy. The methods I shared give you a solid foundation to build and train your own segmentation models.
Feel free to customize the architecture or try it on real 3D datasets for improved results. If you want help with dataset preparation or advanced training techniques, just ask!

I am Bijay Kumar, a Microsoft MVP in SharePoint. Apart from SharePoint, I have been working with Python, machine learning, and artificial intelligence for the last 5 years. During this time, I have gained expertise in various Python libraries such as Tkinter, Pandas, NumPy, Turtle, Django, Matplotlib, TensorFlow, SciPy, and scikit-learn, working for clients in the United States, Canada, the United Kingdom, Australia, New Zealand, and elsewhere. Check out my profile.