Scipy KDTree: Nearest Neighbor Searches in Python

Recently, I was working on a project where I needed to find the nearest points in a multidimensional dataset quickly. The brute force approach of calculating distances between all points was painfully slow with large datasets. That’s when I discovered Scipy’s KDTree, which dramatically improved my search efficiency.

In this article, I’ll share how to use KDTree to solve nearest neighbor search problems efficiently in Python. I’ll cover different query methods and practical applications that can save you hours of computation time.

So let’s get started!

What is KDTree and Why Use It?

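KDTree (K-Dimensional Tree) is a space-partitioning data structure that organizes points in a k-dimensional space. Think of it as a binary search tree extended to multiple dimensions: each node splits the points into two groups along one coordinate axis, and the splitting axis usually cycles through the dimensions as you move down the tree.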

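The main advantage of using KDTree is search speed. A brute-force search computes the distance from the query to every point, which costs O(n) per query. A KDTree costs roughly O(n log n) to build, but after that a nearest-neighbor query takes about O(log n) time on average for low-dimensional, reasonably well-spread data, because entire branches of the tree can be skipped. That makes it ideal for running many queries against large datasets. The advantage shrinks as the number of dimensions grows, so KDTree works best with a handful of dimensions.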
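To see why skipping branches makes KDTree fast, here is a toy KD-tree written from scratch with NumPy. It is a minimal sketch for illustration only: build_kdtree and nearest are hypothetical helper names, the tree splits at the median while cycling through the axes, and it is not how SciPy's compiled KDTree is implemented (SciPy, for example, keeps small buckets of points in its leaves; see its leafsize parameter). The key line is the pruning check: the far side of a split is searched only when the splitting plane is closer than the best match found so far.

```python
import numpy as np


def build_kdtree(points, indices=None, depth=0):
    """Recursively build a toy KD-tree as nested dicts (illustration only)."""
    if indices is None:
        indices = np.arange(len(points))
    if len(indices) == 0:
        return None
    axis = depth % points.shape[1]                      # cycle through the dimensions
    order = indices[np.argsort(points[indices, axis])]  # sort this subset along the axis
    mid = len(order) // 2                               # the median point becomes the node
    return {
        "index": order[mid],
        "axis": axis,
        "left": build_kdtree(points, order[:mid], depth + 1),
        "right": build_kdtree(points, order[mid + 1:], depth + 1),
    }


def nearest(node, points, target, best=None):
    """Return (distance, index) of the point nearest to target."""
    if node is None:
        return best
    point = points[node["index"]]
    dist = np.linalg.norm(point - target)
    if best is None or dist < best[0]:
        best = (dist, node["index"])
    diff = target[node["axis"]] - point[node["axis"]]
    # Search the side of the split that contains the target first
    near, far = (node["left"], node["right"]) if diff < 0 else (node["right"], node["left"])
    best = nearest(near, points, target, best)
    # Prune: only visit the other side if the splitting plane is closer than the best match
    if abs(diff) < best[0]:
        best = nearest(far, points, target, best)
    return best


rng = np.random.default_rng(0)
pts = rng.random((1000, 2))
root = build_kdtree(pts)
dist, idx = nearest(root, pts, np.array([0.5, 0.5]))
print(f"Nearest index: {idx}, distance: {dist:.4f}")
```

On random 2D data most subtrees fail the pruning check, so a query touches only a small fraction of the 1000 nodes.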

Method 1 – Basic KDTree Implementation

Let’s start with a simple example of creating a KDTree and finding the nearest neighbor to a point:

import numpy as np
from scipy.spatial import KDTree
import matplotlib.pyplot as plt

# Create random points (e.g., locations of stores in NYC)
np.random.seed(0)
points = np.random.rand(1000, 2)  # 1000 random 2D points

# Create KDTree
tree = KDTree(points)

# Query point (e.g., your current location)
query_point = [0.5, 0.5]

# Find nearest neighbor
distance, index = tree.query(query_point)

print(f"Nearest point index: {index}")
print(f"Distance to nearest point: {distance}")
print(f"Coordinates of nearest point: {points[index]}")

Output:

Nearest point index: 478
Distance to nearest point: 0.02982175984907058
Coordinates of nearest point: [0.48006089 0.52217587]


In this example, we create 1000 random 2D points, build a KDTree, and find the nearest point to [0.5, 0.5]. This could represent finding the closest store to your location in New York City.
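tree.query also accepts many query points at once. The sketch below assumes SciPy 1.6 or newer, where the workers parameter is available: passing an (m, 2) array returns one distance and one index per query row, and workers=-1 spreads the work over all CPU cores. The three query locations are arbitrary examples.

```python
import numpy as np
from scipy.spatial import KDTree

np.random.seed(0)
points = np.random.rand(1000, 2)  # same setup as Method 1
tree = KDTree(points)

# Several query locations at once: an (m, 2) array instead of one [x, y] pair
query_points = np.array([[0.5, 0.5], [0.1, 0.9], [0.8, 0.2]])

# One call answers every query; workers=-1 uses all available CPU cores
distances, indices = tree.query(query_points, workers=-1)

for q, d, i in zip(query_points, distances, indices):
    print(f"Query {q}: nearest index {i}, distance {d:.4f}")
```

Batching avoids a Python-level loop over the queries, which usually matters when you have thousands of locations to look up.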

Method 2 – Find K Nearest Neighbors

Often, you’ll want to find multiple nearest neighbors, not just one:

# Find 5 nearest neighbors
k = 5
distances, indices = tree.query(query_point, k=k)

print(f"Indices of {k} nearest points: {indices}")
print(f"Distances to {k} nearest points: {distances}")

# Visualize results
plt.figure(figsize=(10, 8))
plt.scatter(points[:, 0], points[:, 1], c='blue', alpha=0.5, label='All points')
plt.scatter(query_point[0], query_point[1], c='red', s=100, label='Query point')
plt.scatter(points[indices, 0], points[indices, 1], c='green', s=100, label=f'{k} nearest neighbors')
plt.legend()
plt.title(f"Finding {k} nearest neighbors with KDTree")
plt.show()


This approach is perfect for applications like restaurant recommendations, where you want to show users the 5 closest dining options to their location.
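When only nearby results are useful, tree.query can cap the search with distance_upper_bound. SciPy signals missing neighbors with an infinite distance and an index equal to tree.n (the number of indexed points), so those slots need to be filtered out. The 0.03 cutoff below is an arbitrary value chosen for illustration.

```python
import numpy as np
from scipy.spatial import KDTree

np.random.seed(0)
points = np.random.rand(1000, 2)
tree = KDTree(points)
query_point = [0.5, 0.5]

# Ask for up to 5 neighbors, but ignore anything beyond a distance of 0.03
distances, indices = tree.query(query_point, k=5, distance_upper_bound=0.03)

# Results are sorted nearest first; unfilled slots have distance inf and index tree.n
found = np.isfinite(distances)
print(f"Neighbors found: {indices[found]}")
print(f"Distances: {distances[found]}")
print(f"Empty slots: {np.count_nonzero(~found)}")
```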

Method 3 – Radius-Based Search

Sometimes, rather than finding k nearest neighbors, you need all neighbors within a certain radius:

# Find all points within radius r
r = 0.2
indices = tree.query_ball_point(query_point, r)

print(f"Number of points within radius {r}: {len(indices)}")
print(f"Indices of points within radius {r}: {indices}")

# Visualize results
plt.figure(figsize=(10, 8))
plt.scatter(points[:, 0], points[:, 1], c='blue', alpha=0.5, label='All points')
plt.scatter(query_point[0], query_point[1], c='red', s=100, label='Query point')
plt.scatter(points[indices, 0], points[indices, 1], c='green', s=100, label=f'Points within radius {r}')
plt.legend()
plt.title(f"Finding points within radius {r} using KDTree")
circle = plt.Circle(query_point, r, color='r', fill=False)  # outline of the search radius
plt.gca().add_patch(circle)
plt.gca().set_aspect('equal')  # keep the circle round on screen
plt.show()


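This method is useful for applications like finding all coffee shops within a 1-mile radius of your location in Manhattan. Keep in mind that KDTree measures straight-line (Euclidean, by default) distance in the units of your coordinates, so raw latitude/longitude values must be converted into a representation measured in real distances before a radius such as 1 mile is meaningful.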
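One common way to make a Euclidean KDTree respect real distances on a map is to convert each latitude/longitude into a 3D point on the unit sphere. The straight-line (chord) distance between two such points increases monotonically with their great-circle distance, so a ground radius d converts to a chord radius of 2 * sin(d / (2R)). This sketch assumes a spherical Earth with a mean radius of 6,371 km, and the coordinates are hypothetical, chosen so that shops 0 and 2 lie within one mile of the query location while shops 1 and 3 do not; to_unit_xyz is a hypothetical helper name.

```python
import numpy as np
from scipy.spatial import KDTree

EARTH_RADIUS_M = 6_371_000  # mean Earth radius (spherical-Earth assumption)
MILE_M = 1609.344           # one mile in meters


def to_unit_xyz(lat_deg, lon_deg):
    """Convert latitude/longitude in degrees to 3D points on the unit sphere."""
    lat, lon = np.radians(lat_deg), np.radians(lon_deg)
    return np.column_stack((np.cos(lat) * np.cos(lon),
                            np.cos(lat) * np.sin(lon),
                            np.sin(lat)))


# Hypothetical coffee-shop coordinates (lat, lon), for illustration only
shops = np.array([
    [40.7680, -73.9855],  # about 1.1 km north of the query location
    [40.7780, -73.9855],  # about 2.2 km north
    [40.7580, -73.9705],  # about 1.3 km east
    [40.7580, -73.9605],  # about 2.1 km east
])
me = np.array([[40.7580, -73.9855]])  # hypothetical query location

tree = KDTree(to_unit_xyz(shops[:, 0], shops[:, 1]))

# A great-circle distance d corresponds to a chord of length 2*sin(d / (2R)) on the unit sphere
chord = 2 * np.sin(MILE_M / (2 * EARTH_RADIUS_M))
nearby = tree.query_ball_point(to_unit_xyz(me[:, 0], me[:, 1])[0], chord)
print(f"Shops within 1 mile: {sorted(nearby)}")
```

For a small area such as a single city, projecting to local planar coordinates in meters is another reasonable choice; the unit-sphere approach simply avoids distortion over larger regions.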


Method 4 – KDTree for Classification

KDTree can be used to implement a k-nearest neighbors classifier:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load Iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Build KDTree with training data
tree = KDTree(X_train)

# Predict function using KDTree
def knn_predict(X_test, k=3):
    y_pred = []
    for x in X_test:
        # Find k nearest neighbors
        distances, indices = tree.query(x, k=k)
        # Get their labels
        labels = y_train[indices]
        # Predict the most common label
        y_pred.append(np.bincount(labels).argmax())
    return np.array(y_pred)

# Make predictions
k = 3
y_pred = knn_predict(X_test, k)

# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy with k={k}: {accuracy:.4f}")

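This example shows how to implement a KNN classifier with KDTree to classify iris flowers. On a dataset as small as Iris (150 samples) the speedup is negligible; the benefit appears with large, low-dimensional training sets, where each prediction no longer has to compute the distance to every training point. For production use, scikit-learn's KNeighborsClassifier can use a KD-tree internally via algorithm='kd_tree'.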
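The knn_predict loop calls tree.query once per test sample. A batched variant passes all test rows in a single call and counts the votes with array operations. This is a sketch: knn_predict_batch is a hypothetical helper name, and ties are broken toward the smallest class label, the same behavior as np.bincount(...).argmax() in the original loop.

```python
import numpy as np
from scipy.spatial import KDTree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
tree = KDTree(X_train)


def knn_predict_batch(tree, y_train, X_query, k=3):
    """Hypothetical k-NN classifier built on a single batched tree.query call."""
    _, indices = tree.query(X_query, k=k)
    # With k=1 SciPy returns a 1-D array, so force the shape (n_query, k)
    indices = np.asarray(indices).reshape(len(X_query), -1)
    neighbor_labels = y_train[indices]  # class labels of each row's neighbors
    n_classes = y_train.max() + 1
    # votes[i, c] = number of row i's neighbors that belong to class c
    votes = (neighbor_labels[:, :, None] == np.arange(n_classes)).sum(axis=1)
    return votes.argmax(axis=1)  # argmax picks the smallest label on ties


y_pred = knn_predict_batch(tree, y_train, X_test, k=3)
print(f"Accuracy with k=3: {accuracy_score(y_test, y_pred):.4f}")
```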


Method 5 – Compare Points Between Two KDTrees

You can also find the nearest neighbors between two different sets of points:

# Create two sets of points
points1 = np.random.rand(500, 2)  # Set 1 (e.g., restaurants)
points2 = np.random.rand(300, 2)  # Set 2 (e.g., coffee shops)

# Build KDTrees
tree1 = KDTree(points1)
tree2 = KDTree(points2)  # only needed for the reverse direction or tree-to-tree queries

# For each point in set 2 (coffee shop), find the nearest point in set 1 (restaurant)
distances, indices = tree1.query(points2)

# Pair each coffee shop with its closest restaurant as (restaurant, coffee shop)
closest_pairs = [(int(r_idx), c_idx) for c_idx, r_idx in enumerate(indices)]
print(f"First 5 closest pairs (restaurant, coffee shop): {closest_pairs[:5]}")

# Visualize
plt.figure(figsize=(10, 8))
plt.scatter(points1[:, 0], points1[:, 1], c='blue', alpha=0.5, label='Restaurants')
plt.scatter(points2[:, 0], points2[:, 1], c='red', alpha=0.5, label='Coffee shops')

# Draw lines from the first 10 coffee shops to their closest restaurants
for i in range(10):
    restaurant_idx = closest_pairs[i][0]
    coffee_idx = closest_pairs[i][1]
    plt.plot([points1[restaurant_idx, 0], points2[coffee_idx, 0]], 
             [points1[restaurant_idx, 1], points2[coffee_idx, 1]], 'k-', alpha=0.3)

plt.title("Closest restaurant to each coffee shop")
plt.legend()
plt.show()

This approach is ideal for applications like finding the nearest gas station to each hotel in a city.
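A second tree pays off in tree-to-tree queries. KDTree.query_ball_tree returns, for each point in the first tree, the indices of all points in the other tree within radius r, and KDTree.count_neighbors counts those pairs without listing them. The 0.05 radius and the random data are illustrative assumptions; the names restaurants and coffee_shops mirror the example above.

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(1)
restaurants = rng.random((500, 2))
coffee_shops = rng.random((300, 2))

tree_r = KDTree(restaurants)
tree_c = KDTree(coffee_shops)

r = 0.05
# For each restaurant, the indices of all coffee shops within distance r
matches = tree_r.query_ball_tree(tree_c, r)

# Total number of (restaurant, coffee shop) pairs within distance r
n_pairs = tree_r.count_neighbors(tree_c, r)

print(f"Restaurant 0 has {len(matches[0])} coffee shops within {r}")
print(f"Total pairs within {r}: {n_pairs}")
```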


Performance Comparison

To truly appreciate KDTree’s efficiency, let’s compare it with brute force search:

import time

# Generate a larger dataset
large_points = np.random.rand(10000, 3)  # 10,000 points in 3D
query = np.random.rand(100, 3)  # 100 query points

# Brute force approach: compute the distance to every point for each query
start_time = time.perf_counter()
brute_force_results = []
for q in query:
    distances = np.sqrt(np.sum((large_points - q)**2, axis=1))
    nearest_idx = np.argmin(distances)
    brute_force_results.append((nearest_idx, distances[nearest_idx]))
brute_force_time = time.perf_counter() - start_time

# KDTree approach (the timing includes building the tree)
start_time = time.perf_counter()
tree = KDTree(large_points)
kdtree_results = []
for q in query:
    distance, index = tree.query(q)
    kdtree_results.append((index, distance))
kdtree_time = time.perf_counter() - start_time

# Sanity check: both approaches should find the same nearest neighbors
assert all(bf[0] == kd[0] for bf, kd in zip(brute_force_results, kdtree_results))

print(f"Brute force time: {brute_force_time:.5f} seconds")
print(f"KDTree time: {kdtree_time:.5f} seconds")
print(f"KDTree is {brute_force_time/kdtree_time:.1f}x faster")

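The exact speedup depends on your hardware, the number of points, and the number of dimensions, so benchmark on your own data. The gap widens as the dataset grows, which is why KDTree pays off for large, low-dimensional datasets such as the locations of all businesses in the United States. In high-dimensional spaces, however, its advantage over brute force shrinks considerably.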
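The loop-based brute force in the benchmark is not the strongest possible baseline. A fairer comparison vectorizes brute force with scipy.spatial.distance.cdist and batches the KDTree queries, while timing tree construction separately, since in practice you usually build the tree once and query it many times. This sketch assumes the full distance matrix fits in memory (500 x 20,000 float64 values is about 80 MB); the dataset sizes are arbitrary, and no particular timings are claimed.

```python
import time

import numpy as np
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
large_points = rng.random((20_000, 3))  # 20,000 points in 3D
queries = rng.random((500, 3))          # 500 query points

# Vectorized brute force: one (500 x 20,000) distance matrix, then argmin per row
start = time.perf_counter()
bf_idx = cdist(queries, large_points).argmin(axis=1)
bf_time = time.perf_counter() - start

# KDTree: time the one-off build separately from the batched query
start = time.perf_counter()
tree = KDTree(large_points)
build_time = time.perf_counter() - start

start = time.perf_counter()
_, kd_idx = tree.query(queries)
query_time = time.perf_counter() - start

print(f"Vectorized brute force: {bf_time:.4f} s")
print(f"KDTree build: {build_time:.4f} s, batched query: {query_time:.4f} s")
```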

I hope you found this article helpful for understanding how to use Scipy’s KDTree for efficient nearest neighbor searches. Whether you’re working on location-based services, recommendation systems, or machine learning classification problems, KDTree can significantly speed up your spatial queries and improve your application’s performance.
