HashAttention: Semantic Sparsity for Faster Inference

📅 2024-12-19
🏛️ arXiv.org
📈 Citations: 5
✨ Influential: 1
🤖 AI Summary
To address the high computational cost of attention over long contexts, this paper proposes a semantic-hashing-based sparse attention mechanism. It formulates pivotal-token identification as a learnable recommendation task in Hamming space, mapping keys and queries to binary hash codes via differentiable projections and using bitwise operations to efficiently retrieve the most relevant tokens. This work is the first to deeply integrate semantic similarity modeling with hash-based retrieval, enabling large-scale token pruning without significant quality degradation. Evaluated on Llama-3.1-8B with LongBench, it achieves 32× sparsity with only a 0.6-point average performance drop. Inference speed improves by 3–6× over LightLLM and 2.5–4.5× over gpt-fast on an NVIDIA L4 GPU, while balancing efficiency, accuracy, and memory footprint (32 bits of auxiliary memory per token).
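To make the hashing step concrete, here is a minimal PyTorch sketch assuming a single learned linear projection whose output signs become the hash bits. The function name `hash_codes` and all shapes are illustrative; the paper's learned mapping functions (trained with a differentiable relaxation, since the sign function itself has no gradient) may be more elaborate. Only the inference-time path is shown.

```python
import torch

def hash_codes(x: torch.Tensor, proj: torch.Tensor) -> torch.Tensor:
    """Map key/query vectors to 32-bit binary hash codes.

    x:    (n, d) key or query vectors
    proj: (d, 32) learned projection; the sign of each output gives one bit
    Returns (n,) int64 codes whose low 32 bits hold the hash; a real kernel
    would store them as uint32, i.e. 32 bits of auxiliary memory per token.
    """
    bits = (x @ proj) > 0                        # (n, 32) sign bits
    weights = torch.arange(32, device=x.device)  # bit positions 0..31
    return (bits.long() << weights).sum(dim=-1)  # pack bits into one integer
```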

📝 Abstract
Utilizing longer contexts is increasingly essential to power better AI systems. However, the cost of attending to long contexts is high due to the involved softmax computation. While the scaled dot-product attention (SDPA) exhibits token sparsity, with only a few pivotal tokens significantly contributing to attention, leveraging this sparsity effectively remains an open challenge. Previous methods either suffer from model degradation or require considerable additional resources. We propose HashAttention, a principled approach casting pivotal token identification as a recommendation problem. Given a query, HashAttention encodes keys and queries in Hamming space, capturing the required semantic similarity using learned mapping functions. HashAttention efficiently identifies pivotal tokens for a given query in this Hamming space using bitwise operations, and only these pivotal tokens are used for attention computation, significantly improving overall attention efficiency. HashAttention can reduce the number of tokens used by a factor of $1/32\times$ for the Llama-3.1-8B model with LongBench, keeping average quality loss within 0.6 points, while using only 32 bits per token of auxiliary memory. At $32\times$ sparsity, HashAttention is $3{-}6\times$ faster than LightLLM and $2.5{-}4.5\times$ faster than gpt-fast on an NVIDIA L4 GPU.
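The retrieval step the abstract describes can be sketched in the same illustrative style (not the paper's actual GPU kernel): score every cached key by Hamming distance to the query code using XOR plus a bit-parallel popcount, keep the k nearest keys as pivotal tokens, and run ordinary scaled dot-product attention on that subset only.

```python
import torch
import torch.nn.functional as F

def popcount32(x: torch.Tensor) -> torch.Tensor:
    """Bit-parallel (SWAR) popcount for 32-bit values held in int64 tensors."""
    x = x - ((x >> 1) & 0x55555555)
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333)
    x = (x + (x >> 4)) & 0x0F0F0F0F
    return ((x * 0x01010101) >> 24) & 0xFF       # sum of the four byte counts

def sparse_attention(q, K, V, q_code, k_codes, k=32):
    """Attend only to the k keys nearest to the query in Hamming space.

    q: (d,) query; K, V: (n, d) cached keys/values;
    q_code: scalar int64 hash of q; k_codes: (n,) int64 hashes of K.
    """
    dist = popcount32(q_code ^ k_codes)                  # XOR + popcount = Hamming distance
    idx = torch.topk(-dist, min(k, K.shape[0])).indices  # k smallest distances
    K_s, V_s = K[idx], V[idx]                            # gather pivotal tokens only
    scores = (K_s @ q) / (q.shape[0] ** 0.5)             # SDPA restricted to the subset
    return F.softmax(scores, dim=-1) @ V_s
```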
Problem

Research questions and friction points this paper is trying to address.

Scalability challenge in attention computation for long contexts
Difficulty in exploiting token sparsity without quality degradation
Inefficiency of existing maximum inner-product search (MIPS) solutions for SDPA on GPUs
Innovation

Methods, ideas, or system contributions that make the work stand out.

Encodes keys and queries in Hamming space using learned mapping functions
Identifies pivotal tokens with cheap bitwise operations (see the end-to-end sketch below)
Reduces attention latency and increases inference throughput
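Putting the two sketches above together, a decode step hashes each query once and attends over a 32× smaller pivotal set. Everything here (shapes, and the random projection standing in for learned weights) is illustrative only.

```python
import torch

# Prefill: hash every cached key once; codes live alongside the KV cache.
d, n = 128, 4096
proj = torch.randn(d, 32)                  # learned in the paper; random here
K, V = torch.randn(n, d), torch.randn(n, d)
k_codes = hash_codes(K, proj)

# Decode: hash the query, then attend over the 32x-sparse pivotal set.
q = torch.randn(d)
q_code = hash_codes(q.unsqueeze(0), proj)[0]
out = sparse_attention(q, K, V, q_code, k_codes, k=n // 32)  # 128 tokens
```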