Diagonal Batching Unlocks Parallelism in Recurrent Memory Transformers for Long Contexts

📅 2025-06-05
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
Transformer-based long-context inference suffers from quadratic time complexity and linear memory overhead. Recurrent Memory Transformers (RMTs) reduce this to linear time and constant memory but remain bottlenecked by their serial memory update mechanism. This work introduces Diagonal Batching, a scheduling strategy that enables parallel computation across segments while strictly preserving recurrent dependencies. By rescheduling the computational graph at run time, it achieves the first lossless, training-free, and architecture-transparent parallelization of RMTs, fully compatible with standard Transformer inference frameworks. On 131,072-token sequences, a LLaMA-1B-based ARMT model with Diagonal Batching runs 3.3× faster than the full-attention LLaMA-1B baseline and 1.8× faster than the original sequential RMT implementation, significantly reducing inference latency and hardware cost for long-context reasoning.
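
A toy model helps make the scheduling idea concrete. This is a minimal illustration, not the authors' implementation, and it assumes a layer-wise memory (as in ARMT-style models): computing layer l on segment s then needs only layer l - 1 on the same segment and layer l on the previous segment. Under that assumption, every (segment, layer) cell with the same s + l is independent of the others, so S segments and L layers need S + L - 1 serial steps instead of S × L. All function names and sizes below are hypothetical.

```python
"""Toy diagonal schedule for a recurrent memory transformer (sketch).

Assumption (hypothetical model): cell (s, l) = layer l applied to segment s
depends on (s, l - 1), the previous layer on the same segment, and on
(s - 1, l), the same layer's memory from the previous segment.
"""

Schedule = list[list[tuple[int, int]]]


def sequential_schedule(num_segments: int, num_layers: int) -> Schedule:
    # One cell per step: all layers of segment 0, then all layers of segment 1, ...
    return [[(s, l)] for s in range(num_segments) for l in range(num_layers)]


def diagonal_schedule(num_segments: int, num_layers: int) -> Schedule:
    # Step d groups every cell with s + l == d; such cells never depend on each other.
    steps = []
    for d in range(num_segments + num_layers - 1):
        steps.append([(s, d - s)
                      for s in range(num_segments)
                      if 0 <= d - s < num_layers])
    return steps


def respects_dependencies(steps: Schedule) -> bool:
    # Every cell's predecessors must have finished in a strictly earlier step.
    done: set[tuple[int, int]] = set()
    for group in steps:
        for s, l in group:
            if l > 0 and (s, l - 1) not in done:
                return False
            if s > 0 and (s - 1, l) not in done:
                return False
        done.update(group)
    return True


if __name__ == "__main__":
    S, L = 8, 4  # hypothetical: 8 segments, 4 layers
    seq = sequential_schedule(S, L)
    diag = diagonal_schedule(S, L)
    print(len(seq), len(diag))          # 32 11
    print(respects_dependencies(diag))  # True
    print(max(len(g) for g in diag))    # 4
```

With 8 segments and 4 layers, the sequential order takes 32 steps while the diagonal order takes 11, and no diagonal holds more than min(S, L) = 4 cells. For long inputs S is much larger than L, so the number of serial steps approaches S and each batch is at most L cells wide.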

📝 Abstract
Transformer models struggle with long-context inference due to their quadratic time and linear memory complexity. Recurrent Memory Transformers (RMTs) offer a solution by reducing the asymptotic cost to linear time and constant memory usage. However, their memory update mechanism leads to sequential execution, causing a performance bottleneck. We introduce Diagonal Batching, a scheduling scheme that unlocks parallelism across segments in RMTs while preserving exact recurrence. This approach eliminates the sequential constraint, enabling efficient GPU inference even for single long-context inputs without complex batching and pipelining techniques. Because the technique is purely a run-time computation reordering, existing RMT models can adopt it without retraining. Applied to a LLaMA-1B ARMT model, Diagonal Batching yields a 3.3× speedup over standard full-attention LLaMA-1B and a 1.8× speedup over the sequential RMT implementation on 131,072-token sequences. By removing the sequential bottleneck, Diagonal Batching reduces inference cost and latency, thereby strengthening RMTs as a practical solution for real-world, long-context applications.
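
To see why a pure reordering leaves results unchanged, the following sketch runs the same toy recurrence in sequential order and in diagonal order and compares the outputs. It is a minimal sketch under stated assumptions, not the authors' code: toy_layer is a hypothetical stand-in for a transformer block with one memory slot per layer, and each segment is reduced to a single number. Within one diagonal, every cell touches a different segment and a different layer, so no cell reads another cell's result from the same diagonal; in a real model those cells could be dispatched together, although each uses a different layer's weights and would presumably need a grouped or batched kernel.

```python
"""Sketch: diagonal reordering leaves a toy recurrence's outputs unchanged.

Hypothetical cell: layer l maps (hidden, memory) to (new_hidden, new_memory).
Real RMT layers are transformer blocks; these functions only stand in for them.
"""

from typing import Callable

Cell = Callable[[float, float], tuple[float, float]]


def toy_layer(l: int) -> Cell:
    # Arbitrary deterministic mixing of the hidden state and layer-l memory.
    def step(h: float, m: float) -> tuple[float, float]:
        new_h = 0.5 * h + 0.25 * m + l
        new_m = 0.9 * m + 0.1 * new_h
        return new_h, new_m
    return step


def run_sequential(segments: list[float], layers: list[Cell]) -> list[float]:
    memory = [0.0] * len(layers)  # one memory state per layer
    outputs = []
    for x in segments:
        h = x
        for l, layer in enumerate(layers):
            h, memory[l] = layer(h, memory[l])
        outputs.append(h)
    return outputs


def run_diagonal(segments: list[float], layers: list[Cell]) -> list[float]:
    S, L = len(segments), len(layers)
    memory = [0.0] * L
    hidden = list(segments)  # hidden[s]: latest layer output for segment s
    for d in range(S + L - 1):
        # Cells on diagonal d only read results from earlier diagonals,
        # so all of them are computed before any state is written back.
        cells = [(s, d - s) for s in range(S) if 0 <= d - s < L]
        results = [layers[l](hidden[s], memory[l]) for s, l in cells]
        for (s, l), (h, m) in zip(cells, results):
            hidden[s], memory[l] = h, m
    return hidden


if __name__ == "__main__":
    segments = [1.0, 2.0, 3.0, 4.0, 5.0]
    layers = [toy_layer(l) for l in range(3)]
    print(run_sequential(segments, layers) == run_diagonal(segments, layers))  # True
```

In this toy the two orders agree bit for bit, because each cell performs identical arithmetic on identical inputs whichever order it runs in; only the grouping of independent cells into steps changes.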

Problem

Research questions and friction points this paper is trying to address.

Transformers face quadratic time complexity in long-context inference
RMTs reduce cost but suffer from sequential execution bottlenecks
Diagonal Batching enables parallel processing in RMTs without retraining
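
The complexity gap in the first two points can be made concrete by counting attention scores. This is a rough sketch that counts only query-key score entries per layer, ignoring causal masking, MLP cost, and constant factors. It assumes each segment is processed together with a fixed number of memory tokens, as in the original RMT design; the segment size of 1024 and the 16 memory tokens are hypothetical values.

```python
def full_attention_pairs(n_tokens: int) -> int:
    # Full self-attention scores every token against every token: O(N^2).
    return n_tokens * n_tokens


def segmented_attention_pairs(n_tokens: int, segment: int, n_mem: int) -> int:
    # Each segment attends only within its own tokens plus n_mem memory tokens,
    # so the total grows linearly with the number of segments: O(N).
    # A partial final segment is counted as a full one (an upper bound).
    n_segments = -(-n_tokens // segment)  # ceiling division
    return n_segments * (segment + n_mem) ** 2


if __name__ == "__main__":
    n = 131_072
    print(full_attention_pairs(n))                 # 17179869184
    print(segmented_attention_pairs(n, 1024, 16))  # 138444800
```

At 131,072 tokens the full-attention count is about 1.7 × 10¹⁰ score entries per layer, against about 1.4 × 10⁸ for the segmented scheme; doubling the length quadruples the first figure but only doubles the second.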

Innovation

Methods, ideas, or system contributions that make the work stand out.

Diagonal Batching enables parallel segment processing
Preserves exact recurrence without retraining
Reduces inference cost and latency: 3.3× faster than full-attention LLaMA-1B and 1.8× faster than sequential RMT on 131,072-token sequences