Efficient Length-Generalizable Attention via Causal Retrieval for Long-Context Language Modeling

πŸ“… 2024-10-02
πŸ“ˆ Citations: 0
✨ Influential: 0
πŸ€– AI Summary
Transformers face two coupled challenges in long-context modeling: poor length generalization and the quadratic cost of self-attention. This paper proposes Grouped Cross-Attention (GCA), a dynamic context-grouping mechanism built on an end-to-end learnable causal retriever that captures long-range dependencies within a fixed-size attention window. The retriever is trained directly through the autoregressive loss, with no window expansion or post-training length extrapolation. By combining chunked input processing with top-k causal retrieval, GCA generalizes zero-shot from a 16K training context to 16M-token inference, a 1000Γ— scale-up. On the 16M-context passkey retrieval task it attains near-perfect accuracy while substantially reducing memory and compute during both training and inference. The core contribution is the first joint optimization of a learnable causal retriever with grouped cross-attention, departing from conventional length-extrapolation paradigms.

πŸ“ Abstract
Despite the success of Transformers, handling long contexts remains challenging due to the limited length generalization and quadratic complexity of self-attention. Thus, Transformers often require post-training with a larger attention window, significantly increasing computational and memory costs. In this paper, we propose a novel attention mechanism based on dynamic context, Grouped Cross-Attention (GCA), which can generalize to 1000 times the pre-training context length while maintaining the ability to access distant information with a constant attention window size. For a given input sequence, we split it into chunks and use each chunk to retrieve the top-k relevant past chunks for subsequent text generation. Specifically, unlike most previous works that use an off-the-shelf retriever, our key innovation allows the retriever to learn, in an end-to-end manner, how to retrieve past chunks that better minimize the autoregressive loss of subsequent tokens. Such a mechanism accommodates retrieved chunks within a fixed-size attention window to achieve long-range information access, significantly reducing computational and memory costs during training and inference. Experiments show that GCA-based models achieve near-perfect accuracy in passkey retrieval for 16M context lengths, which is 1000 times the training length.
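The retrieve-then-attend loop the abstract describes can be sketched in a few lines. This is a minimal illustrative sketch, not the paper's implementation: the mean-pooled chunk summaries, dot-product retrieval scores, and residual combination are stand-in assumptions, and the real model learns these components end-to-end through the autoregressive loss. What it does show is the structural point: each chunk attends only to the tokens of its top-k retrieved past chunks, so the attention window stays at k Γ— chunk_len regardless of total sequence length.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gca_forward(chunks, k=2):
    """chunks: (n_chunks, chunk_len, d) token embeddings.

    For each chunk, retrieve the top-k *past* chunks (causal: future
    chunks are never visible) and cross-attend only to their tokens,
    keeping the attention window fixed at k * chunk_len.
    """
    n, chunk_len, d = chunks.shape
    summaries = chunks.mean(axis=1)  # stand-in for learned chunk summaries
    outputs = np.zeros_like(chunks)
    for t in range(n):
        if t == 0:
            outputs[t] = chunks[t]  # first chunk has nothing to retrieve
            continue
        scores = summaries[:t] @ summaries[t]        # score past chunks only
        top = np.argsort(scores)[::-1][:k]           # top-k causal retrieval
        mem = chunks[top].reshape(-1, d)             # retrieved tokens
        attn = softmax(chunks[t] @ mem.T / np.sqrt(d))
        outputs[t] = chunks[t] + attn @ mem          # residual cross-attention
    return outputs
```

Note that the cost per chunk is constant in the sequence length, which is why this style of retrieval can extend a 16K training window to multi-million-token inference without growing the attention computation.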
Problem

Research questions and friction points this paper is trying to address.

Long Sequence
Transformer Models
Computational Complexity
Innovation

Methods, ideas, or system contributions that make the work stand out.

Grouped Cross Attention
Efficient Long Sequence Processing
Fixed-size Window Approach