🤖 AI Summary
This work addresses a limitation of existing hierarchical attention mechanisms: they select a fixed number (top-k) of key-value blocks, which blocks gradient flow between the selection and attention stages and cannot adapt to how many blocks a given query actually needs. The authors propose an end-to-end differentiable, adaptively sparse hierarchical attention mechanism that uses α-entmax to select a variable number of key-value blocks per query and passes this selection to the second-stage softmax attention as a structured prior. Unlike other hierarchical attention methods, the resulting mechanism is fully differentiable and non-dispersive. It matches the accuracy of full attention at 75% sparsity and achieves a better Pareto frontier than NSA and InfLLMv2. With a Triton-based implementation, it also runs faster than FlashAttention-3 at inference time.
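The contrast between a fixed top-k budget and adaptive α-entmax selection can be illustrated with a small NumPy sketch. It assumes α = 1.5, for which an exact sort-based algorithm exists, and uses made-up block scores; the function names `entmax15` and `topk_mask` are hypothetical and do not come from the DashAttention code.

```python
import numpy as np

def entmax15(scores):
    """Exact 1.5-entmax of a 1-D score vector via sorting.

    ASSUMPTION: alpha = 1.5 is chosen for illustration only.
    Returns p with p_i = [scores_i / 2 - tau]_+ ** 2, where tau makes p sum to 1.
    Blocks whose score falls below the threshold get exactly zero weight.
    """
    x = np.asarray(scores, dtype=float) / 2.0   # (alpha - 1) * z with alpha = 1.5
    x_srt = np.sort(x)[::-1]                    # descending order
    rho = np.arange(1, x.size + 1)              # candidate support sizes 1..n
    mean = np.cumsum(x_srt) / rho
    mean_sq = np.cumsum(x_srt ** 2) / rho
    ss = rho * (mean_sq - mean ** 2)
    delta = np.clip((1.0 - ss) / rho, 0.0, None)
    tau = mean - np.sqrt(delta)                 # threshold for each candidate support size
    support = int(np.sum(tau <= x_srt))         # actual support size
    tau_star = tau[support - 1]
    return np.clip(x - tau_star, 0.0, None) ** 2

def topk_mask(scores, k):
    """Fixed-budget selection: always keeps exactly k blocks."""
    idx = np.argsort(scores)[::-1][:k]
    mask = np.zeros(len(scores), dtype=bool)
    mask[idx] = True
    return mask

# Hypothetical coarse block scores for two queries over 8 KV blocks.
peaked = np.array([6.0, 0.5, 0.3, 0.2, 0.1, 0.0, -0.1, -0.2])
flat = np.array([1.0, 0.9, 0.95, 0.85, 0.8, 0.9, 0.88, 0.92])

for name, s in [("peaked", peaked), ("flat", flat)]:
    p = entmax15(s)
    print(name, "entmax blocks:", int(np.count_nonzero(p)),
          "top-k blocks:", int(topk_mask(s, 2).sum()))
```

With these scores, 1.5-entmax keeps a single block for the sharply peaked query and all eight blocks for the flat one, while top-k with k = 2 keeps two blocks in both cases.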
📝 Abstract
Current hierarchical attention methods, such as NSA and InfLLMv2, select the top-k relevant key-value (KV) blocks based on coarse attention scores and then apply fine-grained softmax attention to the selected tokens. However, the top-k operation assumes that the number of relevant tokens is the same for every query, and it blocks gradient flow between the sparse and dense stages. In this work, we propose DashAttention (Differentiable and Adaptive Sparse Hierarchical Attention), which uses the adaptively sparse $\alpha$-entmax transformation to select a variable number of blocks for the current query in the first stage. This selection in turn provides a prior for the second-stage softmax attention, keeping the entire hierarchy fully differentiable. Unlike other hierarchical attention methods, DashAttention is non-dispersive, as we show, which translates into better long-context modeling ability. Experiments with large language models (LLMs) show that DashAttention achieves accuracy comparable to full attention at 75% sparsity and a better Pareto frontier than NSA and InfLLMv2, especially in high-sparsity regimes. We also provide an efficient, GPU-aware implementation of DashAttention in Triton, which achieves a speedup of up to over FlashAttention-3 at inference time. Overall, DashAttention offers a cost-effective strategy for modeling long contexts.
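A minimal sketch of a two-stage hierarchy may help make the idea of a first-stage "prior" concrete. It is not the DashAttention algorithm: it substitutes sparsemax (the α = 2 member of the α-entmax family) for a general α, scores each block by the query's dot product with its mean-pooled keys, and injects the block probabilities into the second stage as an additive log-prior on the token logits. All three choices, along with the names `sparsemax` and `hierarchical_attention`, are assumptions made for illustration, and the sketch handles a single query without causal masking.

```python
import numpy as np

def sparsemax(z):
    """Sparsemax, i.e. alpha-entmax with alpha = 2 (Euclidean projection onto the simplex)."""
    z = np.asarray(z, dtype=float)
    z_srt = np.sort(z)[::-1]                      # scores in descending order
    k = np.arange(1, z.size + 1)
    cssv = np.cumsum(z_srt) - 1.0
    support = k[z_srt - cssv / k > 0][-1]         # size of the support
    tau = cssv[support - 1] / support             # threshold; scores below it map to exactly 0
    return np.clip(z - tau, 0.0, None)

def hierarchical_attention(q, K, V, block_size):
    """Single-query, non-causal two-stage attention sketch.

    HYPOTHETICAL choices (not from the paper): mean-pooled block keys for the
    coarse scores, sparsemax for block selection, and an additive log-prior
    that carries the block weights into the token-level softmax.
    """
    n, d = K.shape
    assert n % block_size == 0, "sketch assumes n is a multiple of block_size"
    n_blocks = n // block_size
    block_keys = K.reshape(n_blocks, block_size, d).mean(axis=1)
    block_p = sparsemax(block_keys @ q / np.sqrt(d))    # stage 1: adaptive, sparse
    token_prior = np.repeat(block_p, block_size)        # each token inherits its block weight
    keep = token_prior > 0                              # zero-weight blocks are never read
    logits = K[keep] @ q / np.sqrt(d) + np.log(token_prior[keep])
    w = np.exp(logits - logits.max())                   # stage 2: softmax over kept tokens
    w /= w.sum()
    return w @ V[keep], block_p

rng = np.random.default_rng(0)
K, V = rng.normal(size=(32, 16)), rng.normal(size=(32, 16))
q = rng.normal(size=16)
out, block_p = hierarchical_attention(q, K, V, block_size=4)
print(out.shape, "blocks kept:", np.count_nonzero(block_p), "of", block_p.size)
```

In this sketch, blocks with zero probability are simply excluded from the second stage (equivalent to a log-prior of -inf), so no hard top-k cut is needed; in an autodiff framework, the log-prior term is what would carry gradients from the token-level output back to the scores of the kept blocks.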