🤖 AI Summary
While structured sparse attention can improve inference efficiency, directly sparsifying dense models often causes significant performance degradation. This work proposes the RAT+ architecture, which incorporates full-sequence recurrence and active recurrence learning during dense pretraining, enabling a single model to switch seamlessly at inference time to various sparse attention patterns, such as dilated or top-k block attention, without retraining. The approach unifies dense training with sparse inference within a single model, maintaining high performance after only a brief resolution-adaptation phase. Experiments show that a 1.5B-parameter model incurs accuracy drops of only 0–2 points on commonsense reasoning and 2–3 points on LongBench under sparse inference, while a 2.6B-parameter model further demonstrates the method's scalability.
📝 Abstract
Structured dilated attention offers an appealing inference-time efficiency knob: it reduces attention FLOPs and KV-cache size by a factor of the dilation size D while preserving long-range connectivity. However, we find a persistent failure mode: sparsifying a pretrained dense-attention model to a dilated pattern causes severe accuracy degradation. We introduce RAT+, a dense-pretraining architecture that augments attention with full-sequence recurrence and active recurrence learning. A single RAT+ model is pretrained densely once, then flexibly switched at inference time to dilated attention (optionally with local windows) or to hybrid layer/head compositions, requiring only a short 1B-token resolution adaptation rather than retraining separate sparse models. At 1.5B parameters trained on 100B tokens, RAT+ closely matches dense accuracy at dilation 16 and drops by about 2–3 points at dilation 64 on commonsense reasoning and LongBench tasks, respectively. Moreover, RAT+ outperforms standard attention when sparsified to top-k block attention. We further scale to 2.6B parameters and 200B tokens and observe the same trend.
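To make the efficiency knob concrete, here is a minimal sketch of a causal dilated-attention mask. This is an illustration of the general dilated pattern under assumed conventions (query i attends only to earlier positions at multiples of the dilation), not the paper's implementation; the function name and API are hypothetical.

```python
import numpy as np

def dilated_causal_mask(seq_len: int, dilation: int) -> np.ndarray:
    """Boolean mask: query i attends to key j iff j <= i and (i - j)
    is a multiple of `dilation`. Each query then sees roughly
    seq_len / dilation keys, so attention FLOPs and the retained
    KV cache shrink by ~dilation x, while distant positions on the
    same stride remain directly reachable (long-range connectivity)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & ((i - j) % dilation == 0)

mask = dilated_causal_mask(seq_len=16, dilation=4)
# Under dilation 4, query 15 attends to keys 3, 7, 11, 15.
print(np.flatnonzero(mask[15]))  # → [ 3  7 11 15]
```

The per-query key count drops from i+1 (dense causal) to about (i+1)/D, which is the factor-of-D saving the abstract refers to; sparsifying a model pretrained with the dense mask down to this pattern is what degrades accuracy without RAT+'s recurrence.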