🤖 AI Summary
Recurrent Neural Networks (RNNs) suffer from limited random access to long-sequence history, while Transformers exhibit quadratic computational complexity and poor length generalization. Method: This paper proposes Hierarchical Sparse Attention (HSA), a fine-grained, hardware-aligned sparse attention mechanism that uses token-level semantic similarity to drive dynamic chunk retrieval. It introduces the first token-to-chunk relevance modeling framework to enable cross-length generalization, and it integrates Mamba's state-space modeling with custom hardware-aligned sparse attention kernels. Contribution/Results: HSA achieves linear O(L) time and memory complexity. On the 64M-length passkey retrieval task, it attains 100% accuracy despite training only on 4K-context sequences, and it delivers substantial average gains on downstream benchmarks. Memory consumption remains nearly constant as sequence length grows, and inference is several times faster than comparable-length Transformers.
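The linear-versus-quadratic distinction above can be made concrete by counting attended tokens. The sketch below is illustrative only: the chunk size and top-k values are hypothetical placeholders, not the paper's actual hyperparameters.

```python
def attended_tokens_dense(L):
    # Dense attention: every query attends to all L tokens -> O(L^2) total work.
    return L * L

def attended_tokens_hsa(L, chunk_size=64, top_k=16):
    # Sparse chunk retrieval: each query attends to at most top_k chunks
    # of chunk_size tokens, a constant per query -> O(L) total work.
    return L * min(L, top_k * chunk_size)

# The dense/sparse cost ratio grows linearly with context length:
for L in (4_096, 65_536, 1_048_576):
    ratio = attended_tokens_dense(L) / attended_tokens_hsa(L)
    print(f"L={L:>9}: dense is {ratio:.0f}x more work")
```

With these placeholder settings, each query touches at most 16 × 64 = 1024 tokens regardless of context length, which is why the total cost stays linear in L.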
📝 Abstract
A key advantage of Recurrent Neural Networks (RNNs) over Transformers is that their linear computational and space complexity enables faster training and inference on long sequences. However, RNNs are fundamentally unable to randomly access historical context, and naively integrating attention mechanisms may undermine their efficiency advantages. To overcome this limitation, we propose **H**ierarchical **S**parse **A**ttention (HSA), a novel attention mechanism that equips RNNs with long-range random access while preserving their merits in efficiency and length generalization. HSA divides the input into chunks, selects the top-$k$ chunks, and hierarchically aggregates their information. The core innovation lies in learning token-to-chunk relevance from fine-grained token-level information inside each chunk, which improves the precision of chunk selection at both in-domain and out-of-domain context lengths. To make HSA efficient, we further introduce a hardware-aligned kernel design. Combining HSA with Mamba yields RAMba, which achieves perfect accuracy in passkey retrieval over contexts of up to 64 million tokens despite pre-training on only 4K-length contexts, delivers significant improvements on various downstream tasks, and maintains a nearly constant memory footprint. These results demonstrate RAMba's considerable potential for long-context modeling.
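The mechanism described above can be sketched for a single query token. This is a minimal NumPy illustration under stated assumptions, not the paper's implementation: the log-sum-exp aggregation of token scores into a chunk score, the chunk sizes, and the relevance-weighted combination are all simplifying stand-ins for the learned token-to-chunk relevance model and hardware-aligned kernels.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def hsa_sketch(q, keys, values, chunk_size=4, top_k=2):
    """Hierarchical sparse attention for one query token (illustrative).

    1. Split the key/value history into fixed-size chunks.
    2. Score each chunk by aggregating fine-grained token-level
       query-key similarities inside it (token-to-chunk relevance).
    3. Select the top-k chunks, attend within each, and combine the
       chunk outputs weighted by chunk relevance (hierarchical step).
    """
    L, d = keys.shape
    n_chunks = L // chunk_size
    k_chunks = keys[: n_chunks * chunk_size].reshape(n_chunks, chunk_size, d)
    v_chunks = values[: n_chunks * chunk_size].reshape(n_chunks, chunk_size, d)

    # Token-level similarity of the query to every token, per chunk:
    # shape (n_chunks, chunk_size).
    token_scores = k_chunks @ q / np.sqrt(d)

    # Token-to-chunk relevance: aggregate token scores within each chunk
    # (log-sum-exp here; the paper learns this relevance instead).
    chunk_scores = np.log(np.exp(token_scores).sum(axis=1))

    # Keep only the top-k chunks: per-query cost is O(top_k * chunk_size).
    selected = np.argsort(chunk_scores)[-top_k:]

    # Hierarchical aggregation: attend inside each selected chunk, then
    # mix chunk outputs by normalized chunk relevance.
    chunk_weights = softmax(chunk_scores[selected])
    out = np.zeros(d)
    for w, c in zip(chunk_weights, selected):
        attn = softmax(token_scores[c])          # within-chunk attention
        out += w * (attn @ v_chunks[c])          # relevance-weighted output
    return out
```

Because chunk scores are built from token-level scores rather than a single pooled chunk embedding, a chunk containing one highly relevant token still ranks highly, which is the property the abstract credits for precise selection at unseen context lengths.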