Point Cloud Segmentation with PointNet in Keras

Segmenting 3D point clouds, that is, assigning a class label to every point, is a key task in computer vision and robotics. In my experience working with Keras, PointNet offers a simple yet powerful architecture that handles unordered point sets directly.

In this tutorial, I will guide you through building a PointNet model for point cloud segmentation using Keras. You’ll get full code examples for every step, making it easy to follow and apply.

What is PointNet and Why Use Keras?

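PointNet is a deep learning architecture designed specifically for point clouds. It respects the unordered nature of point sets by applying the same small network (a shared MLP) to every point and then aggregating the results with max pooling, a symmetric function whose output does not depend on point order. This lets it learn features directly from raw 3D coordinates.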
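
To make the order-independence concrete, here is a minimal NumPy sketch, not the Keras model itself. It applies one shared linear layer with ReLU to every point, takes the max over points, and checks that shuffling the points leaves the result unchanged. The sizes and the names `points`, `weights`, and `global_feature` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 16 points with xyz coordinates, 8 feature channels.
points = rng.normal(size=(16, 3))
weights = rng.normal(size=(3, 8))   # one weight matrix shared by every point
bias = rng.normal(size=(8,))

def global_feature(pts):
    """Shared per-point layer (linear + ReLU) followed by a max over points."""
    per_point = np.maximum(pts @ weights + bias, 0.0)   # shape (num_points, 8)
    return per_point.max(axis=0)                        # shape (8,)

# Same set of points, different order.
shuffled = points[rng.permutation(len(points))]

# The comparison holds because max ignores the order of its inputs.
print(np.allclose(global_feature(points), global_feature(shuffled)))
```

Swapping the max for an order-dependent operation, such as flattening the per-point features, would break this property.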

With Keras, PointNet's building blocks map onto standard layers (Conv1D, BatchNormalization, GlobalMaxPooling1D), so the model is straightforward to implement and train for segmentation on 3D data.

Method 1: Build the PointNet Model Architecture in Keras

This method shows how to implement the core PointNet architecture for segmentation.

Step 1: Import Required Libraries

Import the essential libraries needed to build and train the PointNet model.

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

Step 2: Define the Input Transformation Network (T-Net)

Create the T-Net module, a small network that predicts a k × k transformation matrix used to align the input points (k=3) or the intermediate features (k=64).

def tnet(inputs, k=3):
    x = layers.Conv1D(64, 1, activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(128, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(1024, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.GlobalMaxPooling1D()(x)

    x = layers.Dense(512, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)

    # Start at the identity transform: zero kernel plus an identity-matrix bias,
    # so the layer outputs the k x k identity for every sample before training.
    identity_bias = tf.keras.initializers.Constant(np.eye(k).flatten())
    x = layers.Dense(k * k, kernel_initializer='zeros', bias_initializer=identity_bias)(x)
    x = layers.Reshape((k, k))(x)
    return x
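
One common way to make a T-Net start out as a no-op is to give its final dense layer a zero kernel and a bias equal to the flattened identity matrix. The NumPy sketch below mimics that final layer by hand and checks that the resulting transform leaves the points unchanged. The names `hidden`, `kernel`, and `bias` and the toy sizes are illustrative assumptions, not part of Keras.

```python
import numpy as np

k = 3
batch, num_points, hidden_dim = 2, 5, 256

rng = np.random.default_rng(0)
hidden = rng.normal(size=(batch, hidden_dim))   # stand-in for the last hidden activations
kernel = np.zeros((hidden_dim, k * k))          # zero kernel: output ignores the input at init
bias = np.eye(k).flatten()                      # bias holds the identity matrix, row by row

# Final dense layer followed by a reshape to (batch, k, k).
transform = (hidden @ kernel + bias).reshape(batch, k, k)

points = rng.normal(size=(batch, num_points, k))
aligned = points @ transform                    # batched matmul: (B, N, k) x (B, k, k)

print(np.allclose(transform, np.eye(k)))        # every sample gets the identity
print(np.allclose(aligned, points))             # so the points pass through unchanged
```

Starting from the identity means the rest of the network initially sees the raw points, and the alignment is learned gradually during training.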

Step 3: Build the PointNet Segmentation Model

Assemble the full PointNet segmentation model combining transforms and MLPs.

def create_pointnet_segmentation(num_points=2048, num_classes=4):
    inputs = layers.Input(shape=(num_points, 3))

    # Input transform
    tnet1 = tnet(inputs, k=3)
    x = layers.Dot(axes=(2,1))([inputs, tnet1])

    # MLP (64, 64)
    x = layers.Conv1D(64, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(64, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)

    # Feature transform
    tnet2 = tnet(x, k=64)
    x = layers.Dot(axes=(2,1))([x, tnet2])

    # MLP (64, 128, 1024)
    x = layers.Conv1D(64, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(128, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(1024, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)

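    # Global feature: max over all points, then tiled once per point
    global_feat = layers.GlobalMaxPooling1D()(x)
    global_feat = layers.RepeatVector(num_points)(global_feat)

    # Concatenate per-point and global features (1024 + 1024 = 2048 channels).
    # Note: the original PointNet paper concatenates the 64-dim features taken
    # after the feature transform instead, giving 1088 channels.
    x = layers.Concatenate()([x, global_feat])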

    # MLP for segmentation
    x = layers.Conv1D(512, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(256, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(128, 1, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv1D(num_classes, 1, activation='softmax')(x)

    model = models.Model(inputs=inputs, outputs=x)
    return model

model = create_pointnet_segmentation()
model.summary()

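This setup builds a complete PointNet architecture ready for segmentation tasks. The model outputs a tensor of shape (batch, num_points, num_classes), that is, a class probability distribution for every point.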
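
The step that turns a classification-style network into a segmentation one is combining each point's own features with the global feature. The NumPy sketch below walks through the shapes involved, using toy sizes that are assumptions for illustration. It mirrors what GlobalMaxPooling1D, RepeatVector, and Concatenate do in the model above.

```python
import numpy as np

# Toy sizes; the model above uses 2048 points and 1024 channels.
batch, num_points, channels = 2, 6, 4

rng = np.random.default_rng(0)
per_point = rng.normal(size=(batch, num_points, channels))

# Global feature: max over the point axis (what GlobalMaxPooling1D does).
global_feat = per_point.max(axis=1)                                    # (batch, channels)

# Tile the global feature once per point (what RepeatVector does).
tiled = np.repeat(global_feat[:, np.newaxis, :], num_points, axis=1)   # (batch, num_points, channels)

# Concatenate on the channel axis so each point sees local and global context.
combined = np.concatenate([per_point, tiled], axis=-1)                 # (batch, num_points, 2 * channels)

print(global_feat.shape, tiled.shape, combined.shape)
```

Every point receives an identical copy of the global feature, so the per-point layers that follow can classify a point using both its own features and a summary of the whole shape.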

Method 2: Train the PointNet Model on Dummy Data

This method shows how to prepare dummy point cloud data and train the model.

Step 1: Generate Dummy Point Cloud Data

Create synthetic point cloud data and labels for training PointNet.

num_points = 2048
num_classes = 4
num_samples = 100

# Random 3D points
X_train = np.random.rand(num_samples, num_points, 3).astype(np.float32)

# Random labels for each point (segmentation)
y_train = np.random.randint(0, num_classes, size=(num_samples, num_points))
y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)

Step 2: Compile and Train the Model

Compile the segmentation model and train it using the dummy dataset.

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(X_train, y_train, epochs=10, batch_size=8)

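This method lets you quickly verify that the shapes line up and the training loop runs. Because the labels are random, expect accuracy to stay near chance (about 25% for four classes); the run is a smoke test, not a measure of model quality.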
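
As a reference point for reading the training log, the sketch below estimates the chance-level accuracy for random labels. It also shows how per-point class probabilities are turned into labels with an argmax. The random `probs` array is only a stand-in for the model's output, so this is an assumption-laden illustration rather than a real evaluation.

```python
import numpy as np

num_samples, num_points, num_classes = 100, 2048, 4
rng = np.random.default_rng(0)

# Random integer labels, as in the dummy dataset.
labels = rng.integers(0, num_classes, size=(num_samples, num_points))

# Stand-in for the model's output: random per-point class probabilities (softmax of noise).
logits = rng.normal(size=(num_samples, num_points, num_classes))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# The predicted class for each point is the argmax over the class axis.
pred = probs.argmax(axis=-1)

accuracy = (pred == labels).mean()
print(f"chance level: {1 / num_classes:.3f}, measured: {accuracy:.3f}")
```

If training accuracy on the dummy data sits near this value, the pipeline is behaving as expected for labels that carry no signal.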

Tips for Effective Point Cloud Segmentation in Keras

  • Normalize each point cloud: center it at the origin and scale it to fit inside the unit sphere.
  • Use data augmentation such as random jitter and random rotation of the points.
  • Experiment with batch size and learning rate for stable convergence.
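
The first two tips can be sketched in NumPy as follows. The function names, the jitter parameters (`sigma=0.01`, `clip=0.05`), and the choice of z as the up axis are assumptions for illustration; adjust them to your dataset.

```python
import numpy as np

def normalize_unit_sphere(points):
    """Center a (num_points, 3) cloud at the origin and scale it into the unit sphere."""
    centered = points - points.mean(axis=0)
    scale = np.linalg.norm(centered, axis=1).max()
    return centered / scale

def jitter(points, rng, sigma=0.01, clip=0.05):
    """Add small, clipped Gaussian noise to every coordinate."""
    noise = np.clip(rng.normal(scale=sigma, size=points.shape), -clip, clip)
    return points + noise

def rotate_z(points, rng):
    """Rotate the cloud by a random angle about the z axis (assumed to be 'up')."""
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rotation = np.array([[c, -s, 0.0],
                         [s,  c, 0.0],
                         [0.0, 0.0, 1.0]])
    return points @ rotation.T

rng = np.random.default_rng(0)
cloud = rng.random((2048, 3))

normalized = normalize_unit_sphere(cloud)
augmented = jitter(rotate_z(normalized, rng), rng)
```

Normalization is applied to every sample, including at inference time, while jitter and rotation are applied only to training samples. Per-point labels stay valid because neither augmentation reorders the points.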

PointNet is a powerful architecture for point cloud segmentation, and Keras makes it straightforward to implement. The methods I shared give you a solid foundation to build and train your own segmentation models.

Feel free to customize the architecture or try it on real 3D datasets to get meaningful results. If you want help with dataset preparation or advanced training techniques, just ask!
