🤖 AI Summary
To address the prohibitively large KV cache size and low retrieval efficiency in large language model (LLM) inference, this paper proposes a nonlinear hashing-based KV cache compression and acceleration method. Our approach introduces three key contributions: (1) a learnable nonlinear hash function that significantly improves the discriminability of key-query embeddings, enabling a ≥5× reduction in hash code length while preserving or improving retrieval accuracy; (2) a lightweight training framework incorporating a Bradley–Terry pairwise ranking loss for efficient, low-resource optimization; and (3) GPU-native CUDA kernels leveraging bit-level operations and dynamic KV candidate pruning. On a single A100 GPU, our method achieves ≤100 μs latency for 512K-token retrieval, delivers up to 3× higher end-to-end decoding throughput than vanilla decoding, reduces the memory overhead of hash-based retrieval, and maintains the original generation quality.
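
As a rough illustration of the first two contributions, the sketch below pairs a small MLP hash head with a Bradley–Terry pairwise ranking objective: keys that receive higher true attention from a query are trained to receive higher hash similarity. The layer sizes, the tanh relaxation, and the pair-sampling scheme are our own assumptions for readability, not the paper's exact recipe.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NonlinearHash(nn.Module):
    """Small MLP mapping query/key embeddings to short hash codes.
    Architecture details here are illustrative, not the paper's."""
    def __init__(self, dim: int, code_bits: int = 64, hidden: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, code_bits),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # tanh gives soft codes in (-1, 1) for training;
        # torch.sign(...) yields the actual bits at inference time.
        return torch.tanh(self.net(x))

def bradley_terry_loss(q_code, k_codes, attn_scores, num_pairs=256):
    """Pairwise ranking loss: for randomly sampled key pairs, the key with
    the larger true attention score should win (get higher hash similarity)."""
    sim = q_code @ k_codes.T                      # (B, N) hash similarities
    B, N = sim.shape
    i = torch.randint(0, N, (B, num_pairs), device=sim.device)
    j = torch.randint(0, N, (B, num_pairs), device=sim.device)
    # 1 if key i of the pair has the larger attention score, else 0
    target = (attn_scores.gather(1, i) > attn_scores.gather(1, j)).float()
    margin = sim.gather(1, i) - sim.gather(1, j)  # Bradley–Terry logit
    return F.binary_cross_entropy_with_logits(margin, target)
```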
📝 Abstract
Reducing the key-value (KV) cache burden in Large Language Models (LLMs) significantly accelerates inference. Dynamically selecting critical KV caches during decoding helps maintain performance. Existing methods use random linear hashing to identify important tokens, but this approach is inefficient due to the orthogonal distribution of queries and keys within two narrow cones in LLMs. We introduce Spotlight Attention, a novel method that employs non-linear hashing functions to optimize the embedding distribution of queries and keys, enhancing coding efficiency and robustness. We also develop a lightweight, stable training framework using a Bradley–Terry ranking-based loss, enabling optimization of the non-linear hashing module on GPUs with 16GB memory in 8 hours. Experimental results show that Spotlight Attention drastically improves retrieval precision while shortening the hash code by at least 5× compared to traditional linear hashing. Finally, we exploit the computational advantages of bitwise operations by implementing specialized CUDA kernels, achieving hashing retrieval for 512K tokens in under 100 μs on a single A100 GPU, with end-to-end throughput up to 3× higher than vanilla decoding.
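
To make the bitwise retrieval step concrete, here is a minimal PyTorch reference for packing hash codes and ranking cached keys by Hamming distance (XOR followed by popcount). The paper's actual implementation fuses these steps into specialized CUDA kernels; the function names, 64-bit packing, and the software popcount loop below are our own assumptions for illustration.

```python
import torch

def pack_codes(codes: torch.Tensor) -> torch.Tensor:
    """Pack binary hash codes of shape (N, 64) with entries in {-1, +1}
    into one int64 word per code. 64-bit codes are an illustrative choice."""
    bits = (codes > 0).to(torch.int64)                               # (N, 64)
    weights = torch.arange(64, device=codes.device, dtype=torch.int64)
    return (bits << weights).sum(dim=-1)                             # (N,)

def popcount64(x: torch.Tensor) -> torch.Tensor:
    """Count set bits per word; PyTorch has no native popcount for integer
    tensors, whereas a CUDA kernel could use a hardware popcount instruction."""
    count = torch.zeros_like(x)
    for _ in range(64):
        count += x & 1
        x = (x >> 1) & 0x7FFFFFFFFFFFFFFF  # clear sign bit: logical shift
    return count

def topk_by_hamming(q_word: torch.Tensor, k_words: torch.Tensor, k: int):
    """Return indices of the k cached keys whose hash codes are closest to
    the query code in Hamming distance (XOR, then popcount)."""
    dist = popcount64(torch.bitwise_xor(q_word, k_words))
    return torch.topk(dist, k, largest=False).indices
```

The selected indices would then be used to gather the corresponding KV entries, so full attention is computed only over the retrieved subset rather than the entire 512K-token cache.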