🤖 AI Summary
This work addresses two problems in linear attention: naive stochastic gradient descent (SGD) updates suffer from rapid information decay and suboptimal convergence, and momentum-based alternatives are hard to make both efficient to train and effective. The authors propose Momentum DeltaNet (MDN), which integrates momentum into linear attention in a stable and parallelizable manner. MDN geometrically reorders the update coefficients to obtain a chunkwise parallel algorithm for a stepwise momentum rule, and it analyzes the momentum recurrence from a dynamical systems perspective as a second-order system with complex conjugate eigenvalues, which guides the design of stability-preserving gating constraints. Experiments on 400M- and 1.3B-parameter models show that MDN consistently outperforms strong baselines, including Transformers, Mamba2, and GDN, across diverse downstream benchmarks, while its Triton kernels keep training throughput comparable to that of competitive linear models such as Mamba2 and KDA.
📝 Abstract
Linear Attention (LA) offers a promising paradigm for scaling large language models (LLMs) to long sequences by avoiding the quadratic complexity of self-attention. Recent LA models such as Mamba2 and GDN can be interpreted as performing closed-form online stochastic gradient descent (SGD) through their linear recurrences, but naive SGD updates suffer from rapid information decay and suboptimal convergence. While momentum-based optimizers provide a natural remedy, it is challenging to make them both efficient to train and effective. To address this, we develop a chunkwise parallel algorithm for LA with a stepwise momentum rule by geometrically reordering the update coefficients. Further, from a dynamical systems perspective, we analyze the momentum-based recurrence as a second-order system that introduces complex conjugate eigenvalues. This analysis guides the design of stable gating constraints. The resulting model, Momentum DeltaNet (MDN), leverages Triton kernels to achieve training throughput comparable to that of competitive linear models such as Mamba2 and KDA. Extensive experiments on 400M- and 1.3B-parameter models demonstrate consistent performance improvements over strong baselines, including Transformers, Mamba2, and GDN, across diverse downstream evaluation benchmarks. Code: https://github.com/HuuYuLong/MomentumDeltaNet.
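To make the second-order-system remark concrete, the sketch below uses the textbook heavy-ball recurrence on a one-dimensional quadratic, not the MDN update itself, whose exact form is not given in the abstract. With step size `lr`, momentum `beta`, and curvature `curvature` (all hypothetical names chosen for this illustration), the iterate obeys x_{t+1} = (1 + beta - lr * curvature) x_t - beta x_{t-1}. This is a two-state linear system whose eigenvalues form a complex conjugate pair of modulus sqrt(beta) whenever (1 + beta - lr * curvature)^2 < 4 * beta.

```python
import numpy as np


def momentum_companion(lr: float, beta: float, curvature: float) -> np.ndarray:
    """Companion matrix of the heavy-ball recurrence on a 1-D quadratic.

    State is (x_t, x_{t-1}); the update is
    x_{t+1} = (1 + beta - lr * curvature) * x_t - beta * x_{t-1}.
    This is a generic illustration, not the recurrence used in MDN.
    """
    a = 1.0 + beta - lr * curvature
    return np.array([[a, -beta], [1.0, 0.0]])


def analyze(lr: float, beta: float, curvature: float):
    """Return (eigenvalues, has_complex_pair, spectral_radius)."""
    eig = np.linalg.eigvals(momentum_companion(lr, beta, curvature))
    has_complex_pair = bool(np.any(np.abs(eig.imag) > 1e-12))
    spectral_radius = float(np.max(np.abs(eig)))
    return eig, has_complex_pair, spectral_radius


if __name__ == "__main__":
    # Momentum in the oscillatory regime: complex pair, radius sqrt(beta) < 1.
    print(analyze(lr=0.1, beta=0.9, curvature=1.0))
    # No momentum: the system collapses to first order with real eigenvalues.
    print(analyze(lr=0.1, beta=0.0, curvature=1.0))
    # Momentum coefficient above 1: the complex pair leaves the unit circle.
    print(analyze(lr=0.1, beta=1.1, curvature=1.0))
```

In this toy setting the recurrence is stable exactly when the spectral radius stays below 1, which for the complex-pair regime reduces to `beta < 1`. A constraint of that kind is one plausible reading of what a stability-preserving gate would need to enforce, though the constraints actually used in MDN may differ.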