🤖 AI Summary
The quadratic complexity of full attention hinders efficient long-context processing in large language models (LLMs). Training-free sparse attention methods suffer significant performance degradation, while native sparse-attention approaches (e.g., NSA, MoBA) exhibit a paradox: despite aiming to approximate full attention, they learn lower attention sparsity than full-attention models. The paper traces this to a gradient-update deficiency: key-value pairs excluded during sparse training receive neither forward contribution nor backward gradients, so they never learn proper suppression. It proposes SSA (Sparse Sparse Attention), a unified training framework that computes both sparse and full attention and enforces bidirectional output alignment at every layer, preserving gradient flow to all tokens while encouraging stronger sparsity. SSA achieves state-of-the-art results on commonsense reasoning benchmarks under both sparse- and full-attention inference, adapts smoothly to varying sparsity budgets for flexible compute–accuracy trade-offs at inference time, and improves long-context extrapolation by mitigating the over-allocation of attention to sink areas.
📝 Abstract
The quadratic complexity of full attention limits efficient long-context processing in large language models (LLMs). Sparse attention mitigates this cost by restricting each query to attend to a subset of previous tokens; however, training-free approaches often lead to severe performance degradation. Native sparse-attention methods (e.g., NSA, MoBA) alleviate this issue, yet exhibit a critical paradox: they produce lower attention sparsity than full-attention models, despite aiming to approximate full attention, which may constrain their effectiveness. We attribute this paradox to gradient update deficiency: low-ranked key-value pairs excluded during sparse training receive neither forward contribution nor backward gradients, and thus never learn proper suppression. To overcome this limitation, we propose SSA (Sparse Sparse Attention), a unified training framework that considers both sparse and full attention and enforces bidirectional alignment at every layer. This design preserves gradient flow to all tokens while explicitly encouraging sparse-attention outputs to align with their full-attention counterparts, thereby promoting stronger sparsity. As a result, SSA achieves state-of-the-art performance under both sparse- and full-attention inference across multiple commonsense benchmarks. Furthermore, SSA enables models to adapt smoothly to varying sparsity budgets; performance improves consistently as more tokens are attended, supporting flexible compute–performance trade-offs at inference time. Finally, we show that native sparse-attention training surprisingly improves long-context extrapolation by mitigating the over-allocation of attention values in sink areas, with SSA demonstrating the strongest extrapolation capability.
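The core idea, computing both a sparse (top-k) and a full attention output and penalizing their divergence at each layer, can be sketched as follows. This is an illustrative NumPy sketch under assumed simplifications (single head, no causal mask, MSE as the alignment objective); names such as `alignment_loss` are hypothetical and not the paper's API:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(q, k, v):
    # standard scaled dot-product attention over all keys
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

def sparse_attention(q, k, v, budget):
    # keep only the top-`budget` keys per query; the rest are masked out
    # (ties at the threshold may keep a few extra keys; fine for a sketch)
    scores = q @ k.T / np.sqrt(q.shape[-1])
    thresh = np.sort(scores, axis=-1)[:, -budget][:, None]
    masked = np.where(scores >= thresh, scores, -np.inf)
    return softmax(masked) @ v

def alignment_loss(q, k, v, budget):
    # per-layer alignment target: make the sparse output match the
    # full-attention output, so excluded tokens still shape training
    return np.mean((sparse_attention(q, k, v, budget)
                    - full_attention(q, k, v)) ** 2)
```

When the budget covers all keys the two outputs coincide and the loss vanishes; shrinking the budget exposes exactly the divergence the framework trains away. In a real model this term would be added to the language-modeling loss so gradients flow through both attention paths.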