🤖 AI Summary
This work addresses the training–inference mismatch in long-context large language models (LLMs): models are typically trained with full-context attention but run with bounded-context or segment-level execution at inference, so execution and state-transition semantics differ between the two phases. The paper proposes a framework in which training and inference share the same segment-level forward execution. During training, gradients propagate only to the key-value (KV) states carried over from the immediately preceding segment, while each head may read older KV states in the forward pass under a head-specific access strategy, without those states taking part in gradient propagation. On long-context benchmarks, the approach achieves performance comparable to full-context attention and competitive latency-memory trade-offs against strong inference-efficient baselines, and it reduces peak prefill memory by approximately 6x at 128K context relative to full-context attention with FlashAttention.
📝 Abstract
Transformer-based large language models face severe scalability challenges in long-context generation due to the computational and memory costs of full-context attention. Under practical computation and memory constraints, many inference-efficient long-context methods improve efficiency by adopting bounded-context or segment-level execution only during inference, while continuing to train models under full-context attention. This results in a mismatch between training and inference in both execution and state-transition semantics. Motivated by this observation, we propose a training-inference consistent segment-level generation framework, in which training and inference follow the same segment-level forward execution semantics. During training, consistency with inference is enforced by restricting gradient propagation to KV states carried over from the immediately preceding segment, while permitting head-specific access to past KV states during the forward pass without involving them in gradient propagation. Across long-context benchmarks, our approach achieves performance comparable to full-context attention, offers competitive latency-memory trade-offs against strong inference-efficient baselines, and substantially improves scalability at very long context lengths (e.g., approximately 6x lower peak prefill memory at 128K compared to full-context attention with FlashAttention).
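The gradient rule described in the abstract can be illustrated with a small bookkeeping sketch. It is not the authors' implementation: the names `KVAccess` and `plan_kv_access`, and the `history_per_head` policy (how many older segments each head may read), are hypothetical and chosen only for illustration, since the abstract does not specify how head-specific access is decided. The sketch only lists, for each segment and head, which past segments' KV states would receive gradients and which would be read in the forward pass alone.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class KVAccess:
    """What one attention head may read while processing one segment."""
    segment: int             # index of the segment being processed
    head: int                # attention head index
    with_grad: List[int]     # past segments whose KV states receive gradients
    forward_only: List[int]  # past segments read in the forward pass only


def plan_kv_access(num_segments: int,
                   history_per_head: Dict[int, int]) -> List[KVAccess]:
    """Build a per-segment, per-head KV access plan.

    history_per_head maps a head index to the number of segments older than
    the immediately preceding one that the head may read. This policy is an
    assumption made for the sketch, not taken from the paper.
    """
    plan = []
    for seg in range(num_segments):
        # Gradients reach only the KV states carried over from segment seg - 1.
        with_grad = [seg - 1] if seg > 0 else []
        for head, extra in sorted(history_per_head.items()):
            # Older segments are visible in the forward pass but would be
            # treated as constants (no gradient flows into them).
            oldest = max(0, seg - 1 - extra)
            forward_only = list(range(oldest, max(0, seg - 1)))
            plan.append(KVAccess(seg, head, with_grad, forward_only))
    return plan


if __name__ == "__main__":
    # Head 0 reads only the preceding segment; head 1 also reads two older ones.
    for access in plan_kv_access(num_segments=4, history_per_head={0: 0, 1: 2}):
        print(access)
```

In an autograd framework, the `forward_only` entries would correspond to cached KV tensors that are detached from the computation graph before being concatenated with the current segment's keys and values, while the `with_grad` entry stays attached. Because the same plan would be followed at inference, the forward computation is identical in both phases.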