🤖 AI Summary
This work addresses Top-$K$ selection, a common performance bottleneck in machine learning workloads on accelerators, by generalizing a two-stage approximate algorithm: (1) partition the input array into blocks and extract the top-$K'$ candidates from each block, for some $1 \leq K' \leq K$, then (2) exactly sort the reduced candidate set and select the global top-$K$ elements. We derive a closed-form expression for the expected recall of this generalized algorithm, and prove a bound on the expected recall of the original ($K' = 1$) algorithm that is tighter by a factor of $2$ than the previously published one. We further show that choosing $K' > 1$ with fewer partitions reduces the second-stage input size more effectively while maintaining the same expected recall. Combining this probabilistic analysis with a hardware-aware implementation on Cloud TPUv5e, our algorithm delivers roughly an order-of-magnitude (up to 10×) speedup over the original algorithm without sacrificing recall on real-world tasks.
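To make the two-stage structure concrete, here is a minimal JAX sketch of the generalized algorithm as described above. This is not the authors' TPU-optimized implementation; the function name `approx_top_k` and all parameter choices are illustrative, and it assumes `num_blocks` divides the input length, $K' \leq$ block size, and $K \leq$ `num_blocks` $\times K'$.

```python
import jax

def approx_top_k(x, k, k_prime, num_blocks):
    """Two-stage approximate Top-K: top-k' per block, then exact Top-K.

    Assumes num_blocks divides len(x), k_prime <= block size,
    and k <= num_blocks * k_prime.
    """
    # Stage 1: partition into blocks and keep the top-k' of each block.
    blocks = x.reshape(num_blocks, -1)               # (num_blocks, block_size)
    candidates, _ = jax.lax.top_k(blocks, k_prime)   # (num_blocks, k_prime)
    # Stage 2: exact Top-K over the much smaller pooled candidate set.
    values, _ = jax.lax.top_k(candidates.reshape(-1), k)
    return values

# Example: 2^16 inputs split into 512 blocks of 128; keeping the top-2 of
# each block means stage 2 sorts only 1024 candidates instead of 65536.
x = jax.random.normal(jax.random.PRNGKey(0), (1 << 16,))
top64 = approx_top_k(x, k=64, k_prime=2, num_blocks=512)
```

Stage 1 is embarrassingly parallel across blocks, which is what makes the scheme accelerator-friendly; stage 2 runs an exact Top-$K$ on an input that is `num_blocks * k_prime / len(x)` the size of the original.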
📝 Abstract
We consider the Top-$K$ selection problem, which aims to identify the largest $K$ elements from an array. Top-$K$ selection arises in many machine learning algorithms and often becomes a bottleneck on accelerators, which are optimized for dense matrix multiplications. To address this problem, \citet{chern2022tpuknnknearestneighbor} proposed a fast two-stage \textit{approximate} Top-$K$ algorithm: (i) partition the input array and select the top-$1$ element from each partition, (ii) sort this \textit{smaller subset} and return the top $K$ elements. In this paper, we consider a generalized version of this algorithm, where the first stage selects top-$K'$ elements, for some $1 \leq K' \leq K$, from each partition. Our contributions are as follows: (i) we derive an expression for the expected recall of this generalized algorithm and show that choosing $K' > 1$ with fewer partitions in the first stage reduces the input size to the second stage more effectively while maintaining the same expected recall as the original algorithm, (ii) we derive a bound on the expected recall for the original algorithm in \citet{chern2022tpuknnknearestneighbor} that is provably tighter by a factor of $2$ than the one in that paper, and (iii) we implement our algorithm on Cloud TPUv5e and achieve around an order of magnitude speedups over the original algorithm without sacrificing recall on real-world tasks.
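The recall trade-off in contribution (i) can also be checked empirically. The Monte Carlo sketch below (not from the paper; the parameter choices are illustrative, and it reuses the hypothetical `approx_top_k` from the sketch above) estimates recall on random inputs for two configurations: $K' = 1$ with many partitions versus $K' = 2$ with far fewer partitions, the latter sending a 4× smaller candidate set into stage 2.

```python
import jax
import jax.numpy as jnp

def recall(x, k, k_prime, num_blocks):
    """Fraction of the exact Top-K recovered by the approximate pass."""
    exact, _ = jax.lax.top_k(x, k)
    approx = approx_top_k(x, k, k_prime, num_blocks)
    # Inputs are continuous random values, so ties are negligible and a
    # membership test against the exact Top-K values suffices.
    return jnp.isin(approx, exact).mean()

# 100 random trials of 2^16 standard-normal inputs each.
xs = jax.random.normal(jax.random.PRNGKey(1), (100, 1 << 16))
# K'=1 with 4096 blocks: 4096 candidates enter stage 2.
r1 = jax.vmap(lambda x: recall(x, 64, 1, 4096))(xs).mean()
# K'=2 with 512 blocks: only 1024 candidates enter stage 2.
r2 = jax.vmap(lambda x: recall(x, 64, 2, 512))(xs).mean()
print(f"K'=1, 4096 blocks: mean recall {r1:.4f}")
print(f"K'=2,  512 blocks: mean recall {r2:.4f}")
```

On synthetic data like this, the $K' = 2$ configuration should maintain comparable recall despite the much smaller second-stage input; the paper's closed-form expression makes the trade-off precise for general $K'$ and partition counts.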