🤖 AI Summary
This work addresses the trade-off between expressivity and computational efficiency in generative sequence modeling: Transformers offer strong representational capacity but at high computational cost, while existing efficient linear models are limited to shallow, single-step updates. To overcome this, the authors propose PRISM, a novel architecture that, for the first time, realizes a multi-step iterative optimization process in a parallel feedforward form. PRISM combines a Write-Forget decoupling mechanism, short-convolution-initialized residuals, and a learnable predictor into a two-stage proxy framework. This design enables Rank-L cumulative updates while preserving hardware parallelism, breaking the conventional Rank-1 update bottleneck. Both theoretical analysis and experiments show that PRISM matches the performance of explicit optimization methods while achieving up to a 174× throughput improvement.
📝 Abstract
Generative sequence modeling faces a fundamental tension between the expressivity of Transformers and the efficiency of linear sequence models. Existing efficient architectures are theoretically bounded by shallow, single-step linear updates, while powerful iterative methods like Test-Time Training (TTT) break hardware parallelism due to state-dependent gradients. We propose PRISM (Parallel Residual Iterative Sequence Model) to resolve this tension. PRISM introduces a solver-inspired inductive bias that captures key structural properties of multi-step refinement in a parallelizable form. We employ a Write-Forget Decoupling strategy that isolates non-linearity within the injection operator. To bypass the serial dependency of explicit solvers, PRISM utilizes a two-stage proxy architecture: a short convolution anchors the initial residual using local history energy, while a learned predictor estimates the refinement updates directly from the input. This design distills structural patterns associated with iterative correction into a parallelizable feedforward operator. Theoretically, we prove that this formulation achieves Rank-$L$ accumulation, structurally expanding the update manifold beyond the single-step Rank-$1$ bottleneck. Empirically, PRISM matches the performance of explicit optimization methods while delivering 174× higher throughput.
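The Rank-$L$ accumulation claim can be illustrated with a minimal sketch. All names here (`rank1_update`, `rank_l_update`, the predictor matrices `W_k`, `W_v`) are hypothetical stand-ins rather than PRISM's actual implementation; the point is only that summing $L$ distinct outer-product refinements per token expands the state update from rank $1$ to rank up to $L$, while each refinement direction is computed from the input alone, not from the evolving state.

```python
import numpy as np

def rank1_update(S, k, v):
    """Conventional linear-model step: one outer product per token (a rank-1 update)."""
    return S + np.outer(v, k)

def rank_l_update(S, k, v, W_k, W_v, L=4):
    """Hedged sketch of Rank-L accumulation.

    A hypothetical learned predictor (here just fixed matrices W_k, W_v with a
    tanh non-linearity) proposes L refinement directions from the input alone,
    so all L terms are computable without a serial, state-dependent loop.
    """
    update = np.zeros_like(S)
    for _ in range(L):
        k = np.tanh(W_k @ k)  # predictor proposes the next key direction
        v = np.tanh(W_v @ v)  # ...and the matching value direction
        update += np.outer(v, k)
    return S + update

rng = np.random.default_rng(0)
d, L = 8, 4
S0 = np.zeros((d, d))
k, v = rng.standard_normal(d), rng.standard_normal(d)
W_k, W_v = rng.standard_normal((d, d)), rng.standard_normal((d, d))

# The single-step update has rank 1; the L-step accumulation generically reaches rank L.
r1 = np.linalg.matrix_rank(rank1_update(S0, k, v))
rL = np.linalg.matrix_rank(rank_l_update(S0, k, v, W_k, W_v, L=L))
print(r1, rL)
```

In a state-dependent solver like TTT, each refinement would instead be a gradient step that reads the updated state $S$, forcing a serial loop; replacing that dependency with an input-driven predictor is the structural move that keeps the computation feedforward and parallel.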