🤖 AI Summary
Recurrent Neural Networks (RNNs) suffer from high computational overhead during training due to backpropagation through time (BPTT), limiting scalability. This work proposes SS-RNN, a novel RNN architecture that eliminates BPTT by integrating principles from state-space models (SSMs). Leveraging the assumption of temporal stationarity, SS-RNN employs a static, structured state feedback matrix, reformulating gradient propagation as a forward-mode linear transformation. This constitutes the first incorporation of SSM mechanisms into the RNN gradient computation paradigm. Crucially, SS-RNN preserves effective long-range dependency modeling while enabling efficient, approximate gradient estimation. Empirically, on standard language modeling benchmarks, SS-RNN achieves perplexity competitive with Transformers at comparable parameter counts, while significantly reducing training cost and accelerating inference.
📝 Abstract
Recurrent neural networks (RNNs) have recently demonstrated strong performance and faster inference than Transformers at comparable parameter budgets. However, recursive gradient computation with the backpropagation-through-time (BPTT) algorithm remains their major training bottleneck. In this work, we propose a novel method that replaces BPTT with a fixed gradient feedback mechanism, yielding an efficient approximation of exact gradient propagation under the assumption of temporal stationarity. Our approach leverages state-space model (SSM) principles to define a structured feedback matrix that directly propagates gradients from future time steps. This formulation bypasses recursive gradient backpropagation, significantly reducing training overhead while preserving the network's ability to capture long-term dependencies. Experiments on language modeling benchmarks show competitive perplexity at significantly reduced training cost. These promising results suggest that an SSM-like feedback mechanism can fully exploit the efficiency advantages of RNNs in many practical applications.
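The abstract does not spell out the SS-RNN equations, but the core idea it describes can be sketched in NumPy under explicit assumptions: a toy recurrence `h_t = act(W h_{t-1} + U x_t)` with a simple per-step loss, exact BPTT using the time-varying Jacobian, and the fixed-feedback approximation that replaces every Jacobian with one static matrix `B`. The names `W`, `U`, `B`, the loss, and the choice `B = W` are all illustrative, not the paper's actual construction. For a linear (SSM-like) recurrence the Jacobian is constant, so the static matrix reproduces BPTT exactly; for a tanh recurrence it is only an approximation.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 8, 4
W = rng.normal(scale=0.4, size=(d, d))   # recurrent weights (illustrative)
U = rng.normal(scale=0.4, size=(d, d))   # input weights (illustrative)
x = rng.normal(size=(T, d))

def run(act):
    """Forward pass h_t = act(W h_{t-1} + U x_t); h[0] is the zero state."""
    h = np.zeros((T + 1, d))
    for t in range(T):
        h[t + 1] = act(W @ h[t] + U @ x[t])
    return h

def bptt(h, jac):
    """Exact BPTT for L = 0.5 * sum_t ||h_t||^2: delta_t = e_t + J_{t+1}^T delta_{t+1}."""
    e = h[1:]                       # per-step error signals dL/dh_t
    delta = np.zeros((T, d))
    delta[-1] = e[-1]
    for t in range(T - 2, -1, -1):
        delta[t] = e[t] + jac(h[t + 2]).T @ delta[t + 1]
    return delta

def fixed_feedback(h, B):
    """Propagate gradients from future steps with one static feedback matrix B."""
    e = h[1:]
    delta = np.zeros((T, d))
    delta[-1] = e[-1]
    for t in range(T - 2, -1, -1):
        delta[t] = e[t] + B.T @ delta[t + 1]
    return delta

# Linear (SSM-like) recurrence: the Jacobian is the constant W, so the static
# feedback matrix reproduces BPTT exactly -- temporal stationarity holds.
h_lin = run(lambda z: z)
exact_lin = bptt(h_lin, lambda h_next: W)
approx_lin = fixed_feedback(h_lin, W)
print(np.allclose(exact_lin, approx_lin))  # identical gradients in the linear case

# tanh recurrence: the true Jacobian diag(1 - h^2) W varies with the state,
# so B = W yields only an approximation of the BPTT gradient.
h_tanh = run(np.tanh)
exact_tanh = bptt(h_tanh, lambda h_next: (1.0 - h_next ** 2)[:, None] * W)
approx_tanh = fixed_feedback(h_tanh, W)
print(float(np.max(np.abs(exact_tanh - approx_tanh))))  # approximation error
```

The sketch makes the stationarity trade-off concrete: because `B` never depends on the hidden state, no per-step Jacobians need to be stored or recomputed, at the cost of a gradient bias whenever the recurrence is nonlinear.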