🤖 AI Summary
This work addresses the inefficiency of traditional spatial-tree algorithms on GPUs, which suffer from thread divergence and irregular memory access that hinder parallel performance. To overcome these limitations, the authors propose a GPU-optimized, Morton-ordered flat tree structure that combines a flattened memory layout with z-order encoding to substantially improve memory coalescing and thread cooperation during dual-tree traversal. Implemented in JAX and CUDA, the approach enables highly efficient exact k-nearest neighbour search and Friends-of-Friends clustering, achieving more than an order-of-magnitude speedup over the closest competing GPU libraries on datasets with $N \gtrsim 10^7$ points, while also scaling strongly across distributed multi-GPU systems.
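To make the z-order idea concrete, here is a minimal pure-Python sketch (not the JZ-Tree code) of 3D Morton keys and of why sorting by them yields a flat tree. It assumes points in the unit cube, 10 bits per axis (30-bit keys), and x in the lowest interleaved bit; `spread_bits`, `morton_key`, and `level_ranges` are hypothetical helper names. Once the keys are sorted, every occupied octree cell at a given level is simply a contiguous slice of one flat array, identified by a shared key prefix. This is the property that lets a GPU-oriented layout avoid pointer-chasing.

```python
from typing import List, Sequence, Tuple

BITS = 10  # bits per axis -> 30-bit keys (assumed resolution, illustration only)


def spread_bits(v: int) -> int:
    """Insert two zero bits between each of the low 10 bits of v (bit i -> bit 3i)."""
    v &= 0x3FF
    v = (v | (v << 16)) & 0x030000FF
    v = (v | (v << 8)) & 0x0300F00F
    v = (v | (v << 4)) & 0x030C30C3
    v = (v | (v << 2)) & 0x09249249
    return v


def morton_key(p: Tuple[float, float, float]) -> int:
    """Morton (z-order) key of a point in the unit cube [0, 1]^3."""
    n = 1 << BITS
    # Quantize each coordinate to a grid cell index, clamping 1.0 into the last cell.
    ix, iy, iz = (min(int(c * n), n - 1) for c in p)
    # Interleave as ... z1 y1 x1 z0 y0 x0 (x occupies the lowest bit).
    return spread_bits(ix) | (spread_bits(iy) << 1) | (spread_bits(iz) << 2)


def level_ranges(sorted_keys: Sequence[int], level: int) -> List[Tuple[int, int, int]]:
    """Return (prefix, start, end) for every non-empty octree cell at `level`.

    All points inside one cell share the top 3*level bits of their key, so in
    a sorted key array they occupy the contiguous slice sorted_keys[start:end].
    """
    shift = 3 * (BITS - level)
    ranges: List[Tuple[int, int, int]] = []
    start = 0
    for i in range(1, len(sorted_keys) + 1):
        # Close the current run at the end of the array or when the prefix changes.
        if i == len(sorted_keys) or (sorted_keys[i] >> shift) != (sorted_keys[start] >> shift):
            ranges.append((sorted_keys[start] >> shift, start, i))
            start = i
    return ranges


points = [(0.9, 0.1, 0.2), (0.1, 0.1, 0.1), (0.6, 0.7, 0.8), (0.2, 0.3, 0.1)]
keys = sorted(morton_key(p) for p in points)
print(level_ranges(keys, 1))  # [(0, 0, 2), (1, 2, 3), (7, 3, 4)]
```

At level 1 the prefix is just the octant index, so the two points in the low octant form one slice and the other two points each form their own. A real GPU implementation would compute keys and segment boundaries with vectorized sorts and scans rather than Python loops.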
📝 Abstract
Algorithms based on spatial tree traversal are widely regarded as among the most efficient and flexible approaches for many problems in CPU-based high-performance computing (HPC). However, directly transferring these algorithms to GPU architectures often yields substantially smaller performance gains than expected in light of the high computational throughput of modern GPUs. The branching nature of tree algorithms leads to thread divergence and irregular memory access patterns, both of which may severely limit GPU performance. To address these challenges, we propose a Morton (z-order) 'plane-based tree hierarchy' that is specifically designed for GPU architectures. The resulting flattened data layout enables efficient dual-tree traversal with collaborative execution across thread groups, leading to highly coalesced memory access patterns. Based on this framework, we present implementations of two important spatial algorithms: exact $k$-nearest neighbour search and friends-of-friends (FoF) clustering. For both cases, we observe more than an order-of-magnitude performance improvement over the closest competing GPU libraries for large problem sizes ($N \gtrsim 10^7$), together with strong scaling to distributed multi-GPU systems. We provide an open-source implementation, 'JZ-Tree' (JAX z-order tree), which serves as a foundation for efficient GPU implementations of a broad class of tree-based algorithms.
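For readers unfamiliar with friends-of-friends clustering, the following brute-force reference shows what the algorithm computes, not how JZ-Tree computes it. It is a minimal sketch under stated assumptions: two points are "friends" if their Euclidean distance is at most a linking length `b` (the choice of `<=` rather than `<` is an assumption), and FoF groups are the connected components of that friendship graph, found here with union-find. The function name `fof_labels` is hypothetical. Its $O(N^2)$ pair loop is exactly the cost that a tree traversal is meant to avoid at $N \gtrsim 10^7$.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def fof_labels(points: Sequence[Point], b: float) -> List[int]:
    """Friends-of-friends group labels by brute force over all O(N^2) pairs.

    Points closer than or equal to the linking length b are linked; groups are
    the connected components of the linking graph. Each label is the smallest
    point index in that point's group.
    """
    parent = list(range(len(points)))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving keeps trees shallow
            i = parent[i]
        return i

    def union(i: int, j: int) -> None:
        ri, rj = find(i), find(j)
        if ri != rj:
            # Attach the larger root under the smaller one so labels are deterministic.
            parent[max(ri, rj)] = min(ri, rj)

    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            if math.dist(points[i], points[j]) <= b:
                union(i, j)
    return [find(i) for i in range(len(points))]


pts = [(0.0,), (0.5,), (1.0,), (5.0,), (5.4,), (9.0,)]
print(fof_labels(pts, 0.6))  # [0, 0, 0, 3, 3, 5]
```

Points 0 and 2 are more than `b` apart but still share a group because both are friends of point 1; this transitive chaining is what distinguishes FoF from a simple fixed-radius neighbour query.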