🤖 AI Summary
Recursive Transformers reuse the same hidden state across iterations. This causes two problems: computational homogenization, where every iteration performs similar computation, and entanglement of long-lived and transient information within a single hidden state. Both problems limit performance under parameter constraints. To address this, we propose Memory-as-State-Highways (MeSH), an architecture that externalizes state management into an explicit memory buffer and employs lightweight dynamic routers to enable functional specialization across iterations. Hidden-state probing is used to diagnose these bottlenecks and to confirm that MeSH resolves them. Because the core is reused, computational depth stays decoupled from parameter count. Evaluated on the Pythia suite (160M-1.4B), MeSH consistently improves over recursive baselines. At the 1.4B scale it outperforms the larger non-recursive baseline while using only 67% of its non-embedding parameters, yielding an average +1.06% accuracy gain across downstream tasks. Notably, it not only narrows but reverses the performance gap between recursive and non-recursive models at that scale while preserving the recursive structure.
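The parameter savings come from weight sharing: a looped core is stored once but applied several times. The accounting below is a minimal sketch under stated assumptions. It uses a hypothetical prelude/core/coda layout and assumes equal-sized layers, ignoring the router and memory overhead. `LoopedLayout` and `param_ratio` are illustrative names, not from the paper. With 16 unique layers unrolled to a compute depth of 24, the model keeps 2/3 (about 67%) of a 24-layer baseline's layer parameters, which is about 33% fewer. This layout was chosen only to reproduce that ratio and is not the paper's configuration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LoopedLayout:
    """Hypothetical prelude/core/coda layout (illustrative, not the paper's configuration)."""

    prelude: int     # unique layers applied once before the loop
    core: int        # unique layers reused on every iteration
    iterations: int  # number of passes through the shared core
    coda: int        # unique layers applied once after the loop

    def unique_layers(self) -> int:
        # Parameter depth: each distinct layer's weights are stored once.
        return self.prelude + self.core + self.coda

    def compute_depth(self) -> int:
        # Compute depth: layer applications per token in one forward pass.
        return self.prelude + self.core * self.iterations + self.coda


def param_ratio(layout: LoopedLayout, baseline_layers: int) -> float:
    # Fraction of a non-recursive baseline's layer parameters that the looped model keeps,
    # assuming all layers are equal-sized and ignoring router/memory overhead.
    return layout.unique_layers() / baseline_layers


layout = LoopedLayout(prelude=4, core=8, iterations=2, coda=4)
ratio = param_ratio(layout, baseline_layers=24)
print(layout.compute_depth())  # 24, the same depth as a 24-layer baseline
print(layout.unique_layers())  # 16
print(f"{ratio:.0%} of the parameters, {1 - ratio:.0%} fewer")  # prints "67% of the parameters, 33% fewer"
```

The same arithmetic shows why "67% of the parameters" and "33% fewer parameters" describe one quantity. Under this layout, compute depth can grow with `iterations` while the stored parameter count stays fixed.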
📝 Abstract
Recursive transformers reuse parameters and iterate over hidden states multiple times, decoupling compute depth from parameter depth. However, under matched compute, recursive models with fewer parameters often lag behind their non-recursive counterparts. By probing hidden states, we trace this performance gap to two primary bottlenecks. The first is undifferentiated computation, where the core is forced to adopt a similar computational pattern at every iteration. The second is information overload, where long-lived and transient information must coexist in a single hidden state. To address these issues, we introduce Memory-as-State-Highways (MeSH), a scheme that externalizes state management into an explicit memory buffer and employs lightweight routers to dynamically diversify computation across iterations. Probing visualizations confirm that MeSH resolves these pathologies by inducing functional specialization across iterations. On the Pythia suite (160M-1.4B), MeSH-enhanced recursive transformers consistently improve over recursive baselines. At the 1.4B scale, they outperform their larger non-recursive counterpart, improving average downstream accuracy by 1.06% with 33% fewer non-embedding parameters. Our analysis establishes MeSH as a scalable and principled architecture for building stronger recursive models.
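The abstract does not spell out MeSH's update equations, so the loop below is only a minimal sketch of the general idea under assumed design choices:

- a buffer of `num_slots` memory slots, with the input in slot 0;
- one shared core, using a toy `tanh` linear map as a stand-in for a transformer block;
- per-iteration linear routers that turn the current hidden state into softmax read and write weights over the slots.

All names (`mesh_style_forward`, `read_routers`, `write_routers`) and the additive write rule are hypothetical, not the authors' implementation. Each iteration has its own routers, so the same core can read a different mix of slots and store its output in different places at each step. This loosely mirrors the iteration-level specialization the abstract describes. The buffer also gives long-lived content a place to persist, so it does not have to share one hidden state with transient values.

```python
import math
import random

Vector = list[float]
Matrix = list[list[float]]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(mat: Matrix, vec: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def weighted_sum(weights: Vector, vectors: list[Vector]) -> Vector:
    dim = len(vectors[0])
    return [sum(w * v[i] for w, v in zip(weights, vectors)) for i in range(dim)]


def shared_core(h: Vector, core_weight: Matrix) -> Vector:
    # Toy stand-in for the shared block: the SAME weights are used on every iteration.
    return [math.tanh(z) for z in matvec(core_weight, h)]


def mesh_style_forward(x: Vector, core_weight: Matrix, read_routers: list[Matrix],
                       write_routers: list[Matrix], num_slots: int):
    """Hypothetical MeSH-style loop: explicit memory buffer plus per-iteration routers.

    read_routers[t] and write_routers[t] are (num_slots x dim) matrices mapping a hidden
    state to one logit per memory slot (an assumed router form).
    """
    dim = len(x)
    # Slot 0 holds the input; the remaining slots start empty.
    memory = [list(x)] + [[0.0] * dim for _ in range(num_slots - 1)]
    state = list(x)
    trace = []
    for read_r, write_r in zip(read_routers, write_routers):
        read_w = softmax(matvec(read_r, state))   # which slots to read at this step
        h = weighted_sum(read_w, memory)          # assemble the core's input from memory
        y = shared_core(h, core_weight)           # shared computation
        write_w = softmax(matvec(write_r, y))     # where to store the result
        for s in range(num_slots):                # additive writes keep earlier content
            memory[s] = [m + write_w[s] * v for m, v in zip(memory[s], y)]
        trace.append((read_w, write_w))
        state = y
    return state, memory, trace


def random_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
dim, slots, iters = 4, 3, 3
core = random_matrix(rng, dim, dim)
reads = [random_matrix(rng, slots, dim) for _ in range(iters)]
writes = [random_matrix(rng, slots, dim) for _ in range(iters)]
out, memory, trace = mesh_style_forward([1.0, 0.0, -1.0, 0.5], core, reads, writes, slots)
for t, (read_w, write_w) in enumerate(trace):
    print(t, [round(p, 2) for p in read_w], [round(p, 2) for p in write_w])
```

Printing `trace` shows how the read and write distributions change from one iteration to the next. This is only a crude, loosely analogous check on whether iterations behave differently, not the hidden-state probing used in the paper. If every router is set to zero, each step reads and writes uniformly, which recovers the undifferentiated behaviour the routers are meant to break.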