🤖 AI Summary
This work addresses the challenge that large language models struggle with long-range reasoning and retrieval when operating beyond their pretrained context length. To overcome this limitation, the authors propose L2A (Learning To Attend), a layer that introduces token-level conditional attention: at each transformer layer, the model dynamically decides, per token, whether to invoke global attention, thereby accessing long-range memory in a sparse manner. Combined with custom Triton GPU kernels and post-training KV cache pruning, the method extends the effective context length of Qwen 2.5 and Qwen 3 models from 32K to 128K tokens. Notably, about 80% of tokens skip global attention entirely, yielding roughly 2× higher training throughput, up to a 50% reduction in KV cache memory, and performance within 3% of standard long-context training.
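The core idea of token-wise conditional attention can be illustrated with a minimal sketch. This is not the paper's implementation (which uses fused Triton kernels and a learned gate): here the gate scores are given as an input array, tokens whose score exceeds a threshold attend causally over the whole prefix, and all other tokens attend only over a short local window. All names and the window/threshold values are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def conditional_attention(q, k, v, gate_scores, window=4, threshold=0.5):
    """Token-wise conditional attention (illustrative, not the L2A kernel).

    Tokens with gate_scores[t] > threshold attend causally over the whole
    prefix (global); the rest attend over a local causal window.
    q, k, v: (seq_len, d) arrays; gate_scores: (seq_len,) values in [0, 1].
    """
    seq_len, d = q.shape
    scores = q @ k.T / np.sqrt(d)          # full score matrix, for clarity only
    out = np.zeros_like(v)
    for t in range(seq_len):
        if gate_scores[t] > threshold:     # global: attend to tokens 0..t
            lo = 0
        else:                              # local: attend to a short window
            lo = max(0, t - window + 1)
        w = softmax(scores[t, lo:t + 1])
        out[t] = w @ v[lo:t + 1]
    return out

rng = np.random.default_rng(0)
L, d = 16, 8
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
# Hypothetical gate values; in L2A these come from a learned per-layer predictor.
gates = rng.uniform(size=L)
out = conditional_attention(q, k, v, gates)
print(out.shape)  # (16, 8)
```

With most gate scores below the threshold, the per-token work is O(window) rather than O(seq_len), which is the source of the throughput gain the summary describes.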
📝 Abstract
Language models struggle to generalize beyond pretraining context lengths, limiting long-horizon reasoning and retrieval. Continued pretraining on long-context data can help but is expensive due to the quadratic scaling of Attention. We observe that most tokens do not require (Global) Attention over the entire sequence and can rely on local context. Based on this, we propose L2A (Learning To Attend), a layer that enables conditional (token-wise) long-range memory access by deciding when to invoke global attention. We evaluate L2A on Qwen 2.5 and Qwen 3 models, extending their effective context length from 32K to 128K tokens. L2A matches the performance of standard long-context training to within 3% while skipping Global Attention for $\sim$80% of tokens, outperforming prior baselines. We also design custom Triton kernels to efficiently implement this token-wise conditional Attention on GPUs, achieving up to $\sim$2x improvements in training throughput and time-to-first-token over FlashAttention. Moreover, L2A enables post-training pruning of highly sparse Global Attention layers, reducing KV cache memory by up to 50% with negligible performance loss.
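The post-training pruning step in the last sentence can be sketched with a few lines of arithmetic: if a layer's global-attention path fires for almost no tokens, its global KV cache can be dropped after training. The per-layer rates and the cutoff below are made-up illustrative numbers, not measurements from the paper.

```python
# Hypothetical fraction of tokens that invoked global attention, per layer.
global_rates = [0.02, 0.31, 0.01, 0.45, 0.03, 0.28, 0.00, 0.52]

# Prune the global KV cache of layers whose global path is rarely used
# (cutoff of 5% is an assumption for illustration).
prune = [r < 0.05 for r in global_rates]
kv_saved = sum(prune) / len(global_rates)
print(f"layers pruned: {sum(prune)}/{len(global_rates)}, "
      f"KV cache saved: {kv_saved:.0%}")  # layers pruned: 4/8, KV cache saved: 50%
```

Under this toy profile, half of the layers keep no global KV cache, matching the order of magnitude of the up-to-50% memory reduction reported in the abstract.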