AI Summary
Transformers face dual challenges in long-context modeling: poor length generalization and the high computational complexity of self-attention. This paper proposes Grouped Cross-Attention (GCA), a dynamic context-grouping mechanism built on an end-to-end learnable causal retriever that efficiently captures long-range dependencies within a fixed-size attention window. The retriever is trained directly through the autoregressive loss, without window expansion or post-training length extrapolation. By combining chunked input processing with top-k causal retrieval, GCA achieves zero-shot generalization from the training context length (16K) to an inference length of 16M, a 1000× scale-up. On the 16M-context passkey retrieval task, it attains near-perfect accuracy while substantially reducing memory and computational overhead during both training and inference. The core contribution is the first joint optimization of a learnable causal retriever with grouped cross-attention, departing from conventional length-extrapolation paradigms.
Abstract
Despite the success of Transformers, handling long contexts remains challenging due to the limited length generalization and quadratic complexity of self-attention. Thus, Transformers often require post-training with a larger attention window, significantly increasing computational and memory costs. In this paper, we propose a novel attention mechanism based on dynamic context, Grouped Cross Attention (GCA), which can generalize to 1000 times the pre-training context length while maintaining the ability to access distant information with a constant attention window size. For a given input sequence, we split it into chunks and use each chunk to retrieve the top-k relevant past chunks for subsequent text generation. Specifically, unlike most previous works that use an off-the-shelf retriever, our key innovation allows the retriever to learn, in an end-to-end manner, to retrieve the past chunks that best minimize the auto-regressive loss of subsequent tokens. Such a mechanism accommodates retrieved chunks within a fixed-size attention window to achieve long-range information access, significantly reducing computational and memory costs during both training and inference. Experiments show that GCA-based models achieve near-perfect accuracy in passkey retrieval at 16M context length, 1000 times the training length.
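To make the chunked top-k causal retrieval concrete, here is a minimal NumPy sketch of one GCA step. It is illustrative only: the mean-pooled dot-product scorer stands in for the paper's learned retriever (which is trained end-to-end via the autoregressive loss), and all function names, shapes, and the single-head attention are assumptions, not the authors' implementation.

```python
import numpy as np

def split_into_chunks(x, chunk_size):
    # x: (seq_len, d) token embeddings; assumes seq_len is divisible by chunk_size
    return x.reshape(-1, chunk_size, x.shape[-1])

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gca_step(chunks, t, k):
    """For the chunk at index t, retrieve the top-k most relevant *past*
    chunks (causal: only strictly earlier chunks are visible) and
    cross-attend from the current chunk's tokens to the retrieved tokens.

    The mean-pooled dot-product relevance score below is a stand-in for
    the learnable retriever described in the paper."""
    d = chunks.shape[-1]
    query_repr = chunks[t].mean(axis=0)              # (d,) pooled query for chunk t
    past = chunks[:t]                                # (t, chunk_size, d) past chunks only
    scores = past.mean(axis=1) @ query_repr          # (t,) one relevance score per past chunk
    k = min(k, len(past))
    topk = np.argsort(scores)[-k:]                   # indices of the top-k past chunks
    keys = past[topk].reshape(-1, d)                 # (k * chunk_size, d) retrieved tokens
    attn = softmax(chunks[t] @ keys.T / np.sqrt(d))  # (chunk_size, k * chunk_size)
    return attn @ keys                               # (chunk_size, d) long-range context
```

Because each chunk attends only to its own tokens plus a constant k retrieved chunks, the attention window stays fixed regardless of total sequence length, which is what lets the mechanism scale to contexts far beyond the training length.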