AI Summary
This work addresses the difficulty that existing parallelizable sequence models, such as Transformers and state space models, have in combining efficient training with long-term memory retention. To overcome this limitation, the authors propose the Memory Recurrent Unit (MRU), which introduces multistable dynamics into a parallelizable RNN architecture. By eliminating transient dynamics, MRU stores information persistently while remaining compatible with the parallel scan algorithm, which keeps training efficient. An instantiation of this framework, the bistable MRU (BMRU), achieves good results on tasks with long-term dependencies. BMRU can also be combined with state space models to yield a hybrid architecture that is parallelizable and has both transient dynamics and persistent memory.
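As a rough illustration of what persistent memory means here, the sketch below contrasts two scalar toy units. Neither is the BMRU: the `leaky_linear` and `hysteresis` functions, the decay factor `a`, and the threshold `theta` are hypothetical names and values chosen only for this example. The linear unit has a single stable state and forgets a past input exponentially fast, while the hysteresis unit has two stable states and holds whichever one it was last pushed into.

```python
def leaky_linear(xs, a=0.9):
    """Monostable toy unit: h_t = a * h_{t-1} + x_t with |a| < 1 (assumed form)."""
    h, trace = 0.0, []
    for x in xs:
        h = a * h + x
        trace.append(h)
    return trace


def hysteresis(xs, theta=0.5):
    """Bistable toy unit (not the BMRU): switch on strong inputs, otherwise hold."""
    h, trace = -1.0, []
    for x in xs:
        if x > theta:
            h = 1.0       # strong positive input: jump to the upper state
        elif x < -theta:
            h = -1.0      # strong negative input: jump to the lower state
        # weak input: keep the previous state unchanged
        trace.append(h)
    return trace


# One pulse at t = 0 followed by 200 steps of silence.
xs = [1.0] + [0.0] * 200
linear_trace = leaky_linear(xs)
bistable_trace = hysteresis(xs)

print(linear_trace[-1])    # has decayed back toward the single stable state 0
print(bistable_trace[-1])  # still sits in the state selected by the pulse
```

The hysteresis unit jumps between its stable states without any gradual transient, which is the kind of behaviour the summary attributes to MRUs, though the actual BMRU update rule is not given in this text.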
Abstract
With the emergence of massively parallel processing units, parallelization has become a desirable property for new sequence models. The ability to parallelize the processing of sequences with respect to the sequence length during training is one of the main factors behind the rise of the Transformer architecture. However, Transformers are inefficient at sequence generation, as they need to reprocess all past timesteps at every generation step. Recently, state-space models (SSMs) emerged as a more efficient alternative. These new kinds of recurrent neural networks (RNNs) keep the efficient update of RNNs while gaining parallelization by getting rid of nonlinear dynamics (that is, nonlinear recurrence). SSMs can reach state-of-the-art performance through the efficient training of potentially very large networks, but still suffer from limited representation capabilities. In particular, SSMs cannot exhibit persistent memory, that is, the capacity to retain information for an infinite duration, because of their monostability. In this paper, we introduce a new family of RNNs, the memory recurrent units (MRUs), that combine the persistent memory capabilities of nonlinear RNNs with the parallelizable computations of SSMs. These units leverage multistability as a source of persistent memory, while getting rid of transient dynamics for efficient computations. We then derive a specific implementation as a proof of concept: the bistable memory recurrent unit (BMRU). This new RNN is compatible with the parallel scan algorithm. We show that BMRU achieves good results in tasks with long-term dependencies, and can be combined with state-space models to create hybrid networks that are parallelizable and have transient dynamics as well as persistent memory.
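To make the parallel scan idea concrete, here is a minimal sketch for a scalar linear recurrence of the kind SSMs rely on. The form h_t = a_t * h_{t-1} + b_t, the zero initial state, and the names `combine`, `scan`, and `sequential` are assumptions made for this illustration; this is not the authors' implementation, and it says nothing about how the BMRU update itself is written. The point is that composing two affine updates gives another affine update, so the composition is associative and the prefix results can be computed in about log2(n) rounds, where the work inside each round is independent across positions.

```python
def combine(left, right):
    """Compose two affine updates h -> a*h + b, applying `left` first."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_l * a_r, a_r * b_l + b_r)


def scan(elems):
    """Inclusive scan by recursive doubling (Hillis-Steele style).

    Each round's inner loop has no dependencies between positions, so it
    could run in parallel; here it is written sequentially for clarity.
    """
    elems = list(elems)
    n = len(elems)
    step = 1
    while step < n:
        nxt = list(elems)
        for i in range(step, n):
            nxt[i] = combine(elems[i - step], elems[i])
        elems = nxt
        step *= 2
    return elems


def sequential(a, b):
    """Reference loop: h_t = a_t * h_{t-1} + b_t, starting from h = 0."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


a = [0.5, 2.0, -1.0, 0.25, 1.5]
b = [1.0, 0.0, 3.0, -2.0, 0.5]

# With a zero initial state, the offset of each composed update is h_t.
h_scan = [offset for _, offset in scan(zip(a, b))]
h_loop = sequential(a, b)
print(h_scan)
print(h_loop)
```

A nonlinear recurrence such as h_t = tanh(w * h_{t-1} + x_t) has no comparably simple associative composition, which is why standard nonlinear RNNs are processed step by step during training.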