The KD-Tree
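A k-dimensional tree (KD-Tree) is a space-partitioning data structure for organizing points in a k-dimensional space. It is a binary tree in which every node stores one k-dimensional point and splits the space along one coordinate axis: points with a smaller coordinate on that axis go to the left subtree, the rest go to the right.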
KD-Tree Node:
    # k-Dimensional Node definition
    class KDNode:
        def __init__(self, point, axis, left=None, right=None):
            self.point = point  # k-dimensional coordinate (e.g. [x, y, z])
            self.axis = axis    # Splitting axis index (0 for X, 1 for Y, etc.)
            self.left = left    # Left subtree reference
            self.right = right  # Right subtree reference
To maintain a balanced tree, we partition the dataset at the median coordinate along the splitting axis.
Median Calculation:
    # Select axis and find median point
    def get_median_pivot(points, axis):
        # Sort points along the active axis
        points_sorted = sorted(points, key=lambda p: p[axis])
        median_idx = len(points_sorted) // 2
        return points_sorted[median_idx]
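A KD-tree is built recursively. At each level the splitting axis cycles through the k dimensions (axis = depth mod k), the points are sorted along that axis, and the median point becomes the node. The points before the median form the left subtree and the points after it form the right subtree.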
Tree Construction:
    # Recursive KD-Tree Construction
    def build_kd_tree(points, depth=0):
        if not points:
            return None

        k = len(points[0])
        axis = depth % k

        # sorted() returns a new list, so the caller's list is not reordered
        ordered = sorted(points, key=lambda x: x[axis])
        median = len(ordered) // 2

        return KDNode(
            point=ordered[median],
            axis=axis,
            left=build_kd_tree(ordered[:median], depth + 1),
            right=build_kd_tree(ordered[median + 1:], depth + 1)
        )
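As a quick check of the construction described above, the following sketch builds a tree from six 2-D points and prints it with indentation. The sample points and the `tree_lines` helper are illustrative assumptions, not part of the original text.

```python
class KDNode:
    def __init__(self, point, axis, left=None, right=None):
        self.point = point
        self.axis = axis
        self.left = left
        self.right = right


def build_kd_tree(points, depth=0):
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])  # copy; input list untouched
    median = len(ordered) // 2
    return KDNode(
        point=ordered[median],
        axis=axis,
        left=build_kd_tree(ordered[:median], depth + 1),
        right=build_kd_tree(ordered[median + 1:], depth + 1),
    )


def tree_lines(node, indent=0):
    """Hypothetical helper: render the tree as indented text lines."""
    if node is None:
        return []
    line = " " * indent + f"{node.point} (axis {node.axis})"
    return [line] + tree_lines(node.left, indent + 2) + tree_lines(node.right, indent + 2)


sample = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kd_tree(sample)
print("\n".join(tree_lines(root)))
```

For this input the root is (7, 2), split on x; its left child is (5, 4) and its right child is (9, 6), both split on y. Sorting at every level costs O(n log n) per level, which is acceptable for a sketch; a linear-time median selection would be the usual optimization.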
To perform queries, the tree is traversed from the root, selecting sub-branches based on coordinate comparisons until a leaf cell is reached.
Leaf Traversal:
    # Traverse down to leaf cell containing target point
    def find_leaf_cell(node, target, depth=0):
        if node is None:
            return None

        axis = depth % len(target)
        if target[axis] < node.point[axis]:
            if node.left is None:
                return node
            return find_leaf_cell(node.left, target, depth + 1)
        else:
            if node.right is None:
                return node
            return find_leaf_cell(node.right, target, depth + 1)
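Once a leaf candidate is found, the search backtracks up the tree. At each node we compare the distance from the target to the splitting plane with the current best distance (both as squared distances, so no square root is needed). If the plane is closer, the opposite branch may still hold a nearer point and must be searched. If not, we prune the opposite branch to avoid redundant checks.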
Backtracking Pruning:
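    # Backtracking step: search the opposite branch only if the splitting
    # plane is closer than the best match so far; otherwise prune it.
    # best_dist is a squared distance, so the plane distance is squared too.
    dist_to_boundary = (target[axis] - node.point[axis]) ** 2
    if dist_to_boundary < best_dist:
        # Recurse down the opposite branch
        best_point, best_dist = find_nearest(opposite_branch, target, depth + 1, best_point, best_dist)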
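The pruning step above is only a fragment, so here is one way a complete `find_nearest` could look. This is a minimal sketch: the function body, the `squared_distance` helper, and the sample points are assumptions. Only the signature and the pruning test are taken from the fragment.

```python
import math


class KDNode:
    def __init__(self, point, axis, left=None, right=None):
        self.point, self.axis, self.left, self.right = point, axis, left, right


def build_kd_tree(points, depth=0):
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    median = len(ordered) // 2
    return KDNode(ordered[median], axis,
                  build_kd_tree(ordered[:median], depth + 1),
                  build_kd_tree(ordered[median + 1:], depth + 1))


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def find_nearest(node, target, depth=0, best_point=None, best_dist=math.inf):
    """Return (nearest_point, squared_distance) for target."""
    if node is None:
        return best_point, best_dist

    # The node's own point is a candidate
    dist = squared_distance(node.point, target)
    if dist < best_dist:
        best_point, best_dist = node.point, dist

    # Descend first into the side of the plane that contains the target
    axis = depth % len(target)
    if target[axis] < node.point[axis]:
        near_branch, opposite_branch = node.left, node.right
    else:
        near_branch, opposite_branch = node.right, node.left
    best_point, best_dist = find_nearest(near_branch, target, depth + 1, best_point, best_dist)

    # Backtrack: visit the other side only if the plane is closer than the best match
    dist_to_boundary = (target[axis] - node.point[axis]) ** 2
    if dist_to_boundary < best_dist:
        best_point, best_dist = find_nearest(opposite_branch, target, depth + 1, best_point, best_dist)

    return best_point, best_dist


sample = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kd_tree(sample)
print(find_nearest(root, (9, 2)))  # ((8, 1), 2)
```

In this query the root's left subtree is never visited: the squared distance to the root's splitting plane is 4, which is not smaller than the best squared distance of 2 found on the right.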
To find the K closest points rather than just one, we use a max-priority queue (or heap) of size K to record the closest candidates found so far.
KNN Query:
    # KNN query using a max-priority queue of size K
    import heapq

    def knn_search(node, target, k, heap=None, depth=0):
        # Create a fresh heap per query; a mutable default (heap=[]) would
        # carry results over from one call to the next
        if heap is None:
            heap = []
        if node is None:
            return heap

        dist = sum((a - b) ** 2 for a, b in zip(node.point, target))

        # Maintain heap of size K containing (negative_dist, point);
        # heapq is a min-heap, so negation keeps the farthest candidate at heap[0]
        if len(heap) < k:
            heapq.heappush(heap, (-dist, node.point))
        elif dist < -heap[0][0]:
            heapq.heapreplace(heap, (-dist, node.point))

        axis = depth % len(target)
        next_node = node.left if target[axis] < node.point[axis] else node.right
        other_node = node.right if target[axis] < node.point[axis] else node.left

        # Search subtree containing target first
        knn_search(next_node, target, k, heap, depth + 1)

        # Check other subtree if the splitting plane is closer than the current K-th best distance
        dist_to_plane = (target[axis] - node.point[axis]) ** 2
        if len(heap) < k or dist_to_plane < -heap[0][0]:
            knn_search(other_node, target, k, heap, depth + 1)

        return heap
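To show how the heap contents turn into a usable answer, the sketch below wraps the KNN query in a hypothetical `k_nearest` helper that returns results nearest first. The wrapper, the `heap=None` default, and the sample points are assumptions made for illustration, and `k` is assumed to be at least 1.

```python
import heapq


class KDNode:
    def __init__(self, point, axis, left=None, right=None):
        self.point, self.axis, self.left, self.right = point, axis, left, right


def build_kd_tree(points, depth=0):
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    median = len(ordered) // 2
    return KDNode(ordered[median], axis,
                  build_kd_tree(ordered[:median], depth + 1),
                  build_kd_tree(ordered[median + 1:], depth + 1))


def knn_search(node, target, k, heap=None, depth=0):
    if heap is None:
        heap = []  # fresh heap for every top-level query
    if node is None:
        return heap

    dist = sum((a - b) ** 2 for a, b in zip(node.point, target))
    if len(heap) < k:
        heapq.heappush(heap, (-dist, node.point))
    elif dist < -heap[0][0]:
        heapq.heapreplace(heap, (-dist, node.point))

    axis = depth % len(target)
    go_left = target[axis] < node.point[axis]
    next_node = node.left if go_left else node.right
    other_node = node.right if go_left else node.left

    knn_search(next_node, target, k, heap, depth + 1)
    dist_to_plane = (target[axis] - node.point[axis]) ** 2
    if len(heap) < k or dist_to_plane < -heap[0][0]:
        knn_search(other_node, target, k, heap, depth + 1)
    return heap


def k_nearest(root, target, k):
    """Hypothetical wrapper: (squared_distance, point) pairs, nearest first."""
    return sorted((-neg_dist, point) for neg_dist, point in knn_search(root, target, k))


sample = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kd_tree(sample)
print(k_nearest(root, (9, 2), 3))  # [(2, (8, 1)), (4, (7, 2)), (16, (9, 6))]
```

Storing negated distances is what lets Python's min-heap behave as a max-heap here: `heap[0]` is always the worst of the current K candidates, so it is the one to compare against and replace.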
To retrieve all points within an axis-aligned bounding box, we visit a subtree only if the box extends onto its side of the splitting plane, and prune sub-branches that cannot intersect the box.
Range Search:
    # Range query finding all points within a bounding box
    def range_search(node, box_min, box_max, depth=0, results=None):
        # Create a fresh list per query; a mutable default (results=[]) would
        # accumulate points across separate calls
        if results is None:
            results = []
        if node is None:
            return results

        # Check if current node is inside bounding box
        if all(mn <= val <= mx for val, mn, mx in zip(node.point, box_min, box_max)):
            results.append(node.point)

        axis = depth % len(box_min)
        # Search left subtree if the box extends to or below the splitting plane
        if box_min[axis] <= node.point[axis]:
            range_search(node.left, box_min, box_max, depth + 1, results)
        # Search right subtree if the box extends to or above the splitting plane
        if box_max[axis] >= node.point[axis]:
            range_search(node.right, box_min, box_max, depth + 1, results)

        return results
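A short usage sketch of the range query follows, run twice to show that separate queries do not share results. The sample points and query boxes are illustrative assumptions, and the function uses a `results=None` default so each top-level call starts with an empty list.

```python
class KDNode:
    def __init__(self, point, axis, left=None, right=None):
        self.point, self.axis, self.left, self.right = point, axis, left, right


def build_kd_tree(points, depth=0):
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    median = len(ordered) // 2
    return KDNode(ordered[median], axis,
                  build_kd_tree(ordered[:median], depth + 1),
                  build_kd_tree(ordered[median + 1:], depth + 1))


def range_search(node, box_min, box_max, depth=0, results=None):
    if results is None:
        results = []  # fresh list for every top-level query
    if node is None:
        return results

    # Report the node's point if it lies inside the box on every axis
    if all(mn <= val <= mx for val, mn, mx in zip(node.point, box_min, box_max)):
        results.append(node.point)

    axis = depth % len(box_min)
    if box_min[axis] <= node.point[axis]:   # box reaches the left side
        range_search(node.left, box_min, box_max, depth + 1, results)
    if box_max[axis] >= node.point[axis]:   # box reaches the right side
        range_search(node.right, box_min, box_max, depth + 1, results)
    return results


sample = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build_kd_tree(sample)

first = range_search(root, (4, 1), (8, 5))
second = range_search(root, (0, 0), (3, 3))
print(sorted(first))   # [(5, 4), (7, 2), (8, 1)]
print(sorted(second))  # [(2, 3)]
```

Results come back in traversal order rather than sorted order, which is why the example sorts them before printing.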
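To find spatial overlap between two point clouds, we recursively test the bounding boxes of their KD-trees, avoiding costly all-to-all comparisons. Unlike the one-point-per-node KDNode used so far, this variant assumes that every node stores the bounding box of its subtree and that leaves hold small buckets of points.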
Tree-to-Tree Overlap:
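    # KD-Tree to KD-Tree Spatial Overlap Intersection
    # Assumes bucketed nodes with `bounds`, `is_leaf`, and `points`, plus
    # `boxes_overlap`, `distance`, and `tolerance` defined elsewhere.
    # boxes_overlap should pad the boxes by `tolerance`, or nearby pairs in
    # adjacent boxes are missed.
    def intersect_kd_trees(nodeA, nodeB, overlap_list):
        if nodeA is None or nodeB is None:
            return

        # Prune pairs of subtrees whose bounding boxes do not overlap
        if not boxes_overlap(nodeA.bounds, nodeB.bounds):
            return

        if nodeA.is_leaf and nodeB.is_leaf:
            for pA in nodeA.points:
                for pB in nodeB.points:
                    if distance(pA, pB) < tolerance:
                        overlap_list.append((pA, pB))
            return

        # Descend only on the side that is still internal, so a leaf paired
        # with an internal node is not silently dropped
        partsA = [nodeA] if nodeA.is_leaf else [nodeA.left, nodeA.right]
        partsB = [nodeB] if nodeB.is_leaf else [nodeB.left, nodeB.right]
        for a in partsA:
            for b in partsB:
                intersect_kd_trees(a, b, overlap_list)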
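The overlap routine relies on a node layout and helpers that the text does not define, so the following self-contained sketch fills them in under stated assumptions. `BucketNode`, `build_bucket_tree`, the `leaf_size` parameter, the tolerance-padded `boxes_overlap`, and passing `tolerance` as an argument are all hypothetical choices for illustration, not the original author's implementation.

```python
import math


class BucketNode:
    """Hypothetical bucketed node: bounding box everywhere, points only at leaves."""
    def __init__(self, bounds, points=None, left=None, right=None):
        self.bounds = bounds            # (mins, maxs) of all points below this node
        self.points = points or []
        self.left, self.right = left, right

    @property
    def is_leaf(self):
        return self.left is None and self.right is None


def build_bucket_tree(points, leaf_size=2, depth=0):
    if not points:
        return None
    k = len(points[0])
    bounds = (tuple(min(p[i] for p in points) for i in range(k)),
              tuple(max(p[i] for p in points) for i in range(k)))
    if len(points) <= leaf_size:
        return BucketNode(bounds, points=list(points))
    ordered = sorted(points, key=lambda p: p[depth % k])
    mid = len(ordered) // 2
    return BucketNode(bounds,
                      left=build_bucket_tree(ordered[:mid], leaf_size, depth + 1),
                      right=build_bucket_tree(ordered[mid:], leaf_size, depth + 1))


def boxes_overlap(a, b, tolerance):
    """True if the boxes come within `tolerance` of each other on every axis."""
    (a_min, a_max), (b_min, b_max) = a, b
    return all(lo_a - tolerance <= hi_b and lo_b - tolerance <= hi_a
               for lo_a, hi_a, lo_b, hi_b in zip(a_min, a_max, b_min, b_max))


def intersect_kd_trees(node_a, node_b, tolerance, overlap_list):
    if node_a is None or node_b is None:
        return
    if not boxes_overlap(node_a.bounds, node_b.bounds, tolerance):
        return  # no point pair below these nodes can be within tolerance
    if node_a.is_leaf and node_b.is_leaf:
        for pa in node_a.points:
            for pb in node_b.points:
                if math.dist(pa, pb) < tolerance:
                    overlap_list.append((pa, pb))
        return
    # Split whichever side is still internal; a leaf stays fixed
    parts_a = [node_a] if node_a.is_leaf else [node_a.left, node_a.right]
    parts_b = [node_b] if node_b.is_leaf else [node_b.left, node_b.right]
    for a in parts_a:
        for b in parts_b:
            intersect_kd_trees(a, b, tolerance, overlap_list)


cloud_a = [(0, 0), (1, 1), (2, 2), (5, 5), (6, 1)]
cloud_b = [(0.1, 0.1), (2, 2.05), (5.5, 5), (9, 9)]
pairs = []
intersect_kd_trees(build_bucket_tree(cloud_a), build_bucket_tree(cloud_b), 0.2, pairs)
print(sorted(pairs))  # [((0, 0), (0.1, 0.1)), ((2, 2), (2, 2.05))]
```

The box test is deliberately conservative: it may let through a pair of boxes whose points turn out to be farther apart than the tolerance, but it never rejects a pair that contains a true match, because two points within `tolerance` of each other differ by at most `tolerance` on every axis.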