🤖 AI Summary
Linear state-space models (SSMs) suffer from limited performance on long-range dependency tasks (e.g., retrieval, RAG) due to their reliance on exponentially decaying history summaries. This work proposes Gated KalmaNet (GKA), the first SSM architecture incorporating test-time online ridge regression to model full-history dependencies while preserving linear time complexity and constant memory overhead. Key contributions include: (1) an input-adaptive gating mechanism that dynamically modulates the ridge penalty strength; (2) a numerically stable solver based on Chebyshev iteration, avoiding the instability and parallelization bottlenecks inherent in low-precision Kalman filtering; and (3) a hardware-aware chunk-wise implementation with custom backward kernels enabling end-to-end training. Experiments demonstrate that GKA outperforms Mamba2 and GLA on short-context benchmarks and achieves more than 10% relative improvement on 128k-token LongQA and RAG tasks over strong fading-memory baselines.
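The test-time online ridge regression at the core of GKA can be illustrated with the textbook recursive least-squares update, the same recursion that underlies the Kalman filter. Below is a minimal NumPy sketch under assumed notation (per-token keys `k_t` regressing onto values `v_t`; all names hypothetical), not the paper's actual layer:

```python
import numpy as np

def online_ridge(keys, values, lam=1e-2):
    """Recursively solve the ridge regression
        w_t = argmin_w  sum_{s<=t} ||v_s - w^T k_s||^2 + lam * ||w||^2
    one token at a time with O(d^2) state and O(d^2) work per step
    (the recursive least-squares / Kalman-filter recursion)."""
    d = keys.shape[1]
    P = np.eye(d) / lam                      # running inverse of (lam*I + sum_s k_s k_s^T)
    w = np.zeros((d, values.shape[1]))       # current regression weights
    for k, v in zip(keys, values):
        Pk = P @ k
        gain = Pk / (1.0 + k @ Pk)           # Kalman gain
        w = w + np.outer(gain, v - w.T @ k)  # correct weights by the innovation
        P = P - np.outer(gain, Pk)           # Sherman-Morrison rank-1 downdate
    return w
```

In exact arithmetic this recursion reproduces the closed-form solution $(\lambda I + K^\top K)^{-1} K^\top V$ at every step; the summary's point is that this naive recursion breaks down in low precision (bfloat16) and resists parallelization, which is what the gated regularization and Chebyshev solver are designed to fix.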
📝 Abstract
As efficient alternatives to softmax attention, linear state-space models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, often leading to inferior performance on recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that reduces this gap by accounting for the full past when predicting the next token, while maintaining SSM-style efficiency. GKA achieves this by solving an online ridge regression problem at test time, with constant memory and compute cost linear in the sequence length. Drawing inspiration from the Kalman filter, we solve the online ridge regression problem iteratively. A critical insight, however, is that the standard Kalman filter equations are numerically unstable in low-precision environments (like bfloat16) and difficult to parallelize on modern hardware. We address both challenges through two key innovations: (1) an adaptive regularization strategy with input-dependent gating that controls the condition number of the ridge regression, ensuring numerical stability while balancing memory retention; and (2) the use of Chebyshev iteration instead of conventional iterative solvers, which we demonstrate to be more stable in low-precision settings. To further improve scalability, we develop a hardware-aware chunk-wise implementation of Chebyshev iteration, along with custom kernels for backpropagating through our adaptive regularization and gating mechanisms. Empirically, GKA shows strong language understanding capabilities on short-context tasks, outperforming existing SSM layers (such as Mamba2, GLA, and Gated DeltaNet). On long context, GKA excels at real-world RAG and LongQA tasks up to 128k tokens, achieving more than $10\%$ relative improvement over other fading-memory baselines.
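The Chebyshev iteration the abstract favors is a classical polynomial method for symmetric positive-definite systems: unlike conjugate gradients, its step sizes follow a fixed scalar recurrence from known eigenvalue bounds rather than from inner products of iterates, which helps both parallelization and low-precision robustness. A generic NumPy sketch of the standard three-term recurrence (not the paper's chunk-wise kernel; eigenvalue bounds are assumed given):

```python
import numpy as np

def chebyshev_solve(A, b, lam_min, lam_max, num_iters=80):
    """Chebyshev iteration for A @ x = b, with A symmetric positive definite
    and eigenvalues contained in [lam_min, lam_max]. Uses only matrix-vector
    products and precomputable scalar coefficients -- no inner products."""
    d = 0.5 * (lam_max + lam_min)            # center of the spectrum
    c = 0.5 * (lam_max - lam_min)            # half-width of the spectrum
    x = np.zeros_like(b)
    r = b - A @ x                            # initial residual
    p = np.zeros_like(b)
    alpha = 0.0
    for i in range(num_iters):
        if i == 0:
            p = r.copy()
            alpha = 1.0 / d
        else:
            beta = 0.5 * (c * alpha) ** 2 if i == 1 else (0.5 * c * alpha) ** 2
            alpha = 1.0 / (d - beta / alpha)
            p = r + beta * p
        x = x + alpha * p                    # Chebyshev-optimal step
        r = r - alpha * (A @ p)              # incremental residual update
    return x
```

In GKA's setting, `A` would play the role of the regularized Gram matrix $\lambda I + K^\top K$; the convergence rate depends on its condition number, which is exactly what the input-dependent gated regularization is described as keeping in check.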