🤖 AI Summary
This work addresses how to add cross-position recurrent memory to autoregressive Transformers efficiently while preserving the standard attention mechanism and key-value (KV) cache interface. The authors propose the Latent Recurrent Transformer (LRT), which reuses a high-level source-layer hidden state of the preceding token as recurrent memory for the next token, establishing a cross-layer recurrent pathway across positions with as little as 0.3% additional parameters. They further introduce an interleaved parallel training strategy that gives every token recurrent-memory-aware supervision at roughly twice the baseline compute, without sequentially unrolling the model. Experiments show that, across nanochat-style backbones and a wide range of tokens-per-parameter budgets, LRT improves both language-modeling loss and in-context learning under matched effective compute.
📝 Abstract
We study the Latent Recurrent Transformer (LRT), a lightweight augmentation of autoregressive transformers that reuses a high-level source-layer hidden state from the previous token as recurrent memory for the next token. Because this source state is already computed during ordinary decoding, LRT adds a cross-layer recurrent latent pathway across positions without inserting pause tokens or extra depth loops, and the standard attention mechanism and KV-cache interface are preserved. To pretrain this recurrence at scale without sequentially unrolling the transformer, we introduce interleaved parallel training: a single full-sequence initialization forward pass builds a shared buffer, then disjoint position subsets are refined in parallel and written back, so that all tokens receive recurrent-memory-aware supervision at roughly twice the baseline compute. Across nanochat-style backbones and a wide range of tokens-per-parameter budgets, LRT improves both language-modeling loss and in-context learning under matched effective compute while adding as little as 0.3% additional parameters.
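The interleaved schedule described in the abstract can be illustrated with a toy NumPy sketch. This is not the authors' implementation: `toy_block`, `causal_context`, the zero memory at position 0, the additive memory injection, and the even/odd split (`num_subsets=2`) are all assumptions made for illustration, and a real model would use a transformer stack with causal attention instead of the prefix mean used here. The sketch only shows the control flow: one full-sequence initialization pass fills a shared buffer, then disjoint position subsets read the previous token's buffered state as memory, are recomputed, and are written back.

```python
import numpy as np

def toy_block(x, ctx, mem, params):
    """Hypothetical stand-in for a transformer stack; returns one 'source-layer' state per position."""
    W_x, W_c, W_m = params
    return np.tanh(x @ W_x + ctx @ W_c + mem @ W_m)

def causal_context(x):
    """Causal prefix mean: a crude placeholder for causal attention (assumption)."""
    return np.cumsum(x, axis=0) / np.arange(1, len(x) + 1)[:, None]

def shifted(buffer):
    """Memory for position t is the buffered state of position t-1 (zeros at t=0, an assumption)."""
    mem = np.zeros_like(buffer)
    mem[1:] = buffer[:-1]
    return mem

def interleaved_states(x, params, num_subsets=2):
    """Initialization pass, then refinement of disjoint position subsets with write-back."""
    ctx = causal_context(x)
    # Initialization pass over the full sequence, with no recurrent memory yet.
    buffer = toy_block(x, ctx, np.zeros_like(x), params)
    for k in range(num_subsets):
        idx = np.arange(k, len(x), num_subsets)   # positions k, k+K, k+2K, ...
        mem = shifted(buffer)[idx]                # read previous-token states from the buffer
        buffer[idx] = toy_block(x[idx], ctx[idx], mem, params)  # refine and write back
    return buffer

def sequential_states(x, params):
    """Reference: token-by-token recurrence, as it would run during decoding."""
    ctx = causal_context(x)
    out = np.zeros_like(x)
    prev = np.zeros(x.shape[1])
    for t in range(len(x)):
        out[t] = toy_block(x[t:t + 1], ctx[t:t + 1], prev[None, :], params)[0]
        prev = out[t]
    return out

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, d = 8, 4
    x = rng.normal(size=(T, d))
    params = tuple(0.5 * rng.normal(size=(d, d)) for _ in range(3))
    par = interleaved_states(x, params)
    seq = sequential_states(x, params)
    # Per-position gap between the interleaved buffer and the sequential reference.
    print(np.abs(par - seq).max(axis=1))
```

Under these assumptions the refinement subsets together cover every position once, so the total work is about one initialization pass plus one refinement pass, which is one way to read the "roughly twice the baseline compute" figure. In this toy, the first two positions match the sequential reference exactly, while later positions read buffered states that were themselves computed from earlier approximations; how closely the paper's actual schedule tracks sequential recurrence is not stated in the abstract.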