AI Summary
This work addresses a limitation of looped (recurrent) Transformers. Because they reuse the same weights across iterations, they have limited parameter capacity, and this makes it hard for them to do well at both mathematical reasoning and commonsense tasks. The authors propose an architecture with three parts. An adaptive layer-wise recurrence mechanism lets a learned halting strategy set how many times each layer iterates. Gated external memory banks add learned storage. The approach jointly incorporates adaptive recurrence and learnable external memory. At matched computational cost (FLOPs), it outperforms a baseline with three times as many layers on mathematical reasoning benchmarks. Looping mainly improves mathematical reasoning, while the memory banks help recover performance on commonsense tasks. The study also uncovers emergent functional specialization across layers: early layers loop minimally and access memory sparingly, while later layers do both more heavily. This offers insights for designing efficient reasoning architectures.
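The adaptive per-layer looping described above can be illustrated with a minimal NumPy sketch. It assumes an Adaptive Computation Time (ACT)-style halting rule. At each iteration the block computes a halting probability. The iterated states are mixed using those probabilities, and looping stops once the cumulative probability crosses a threshold. The block update, the pooled scalar halting head, and the names `looped_block`, `W`, `w_halt`, `max_steps`, and `threshold` are hypothetical. The summary does not specify the authors' exact halting rule.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def looped_block(h, W, w_halt, max_steps=8, threshold=0.99):
    """Re-apply one block's weights W to hidden state h with ACT-style halting.

    Hypothetical sketch: the update rule and the pooled halting head are
    illustrative assumptions, not the authors' implementation.
    `threshold` should be <= 1 so the mixture weights stay non-negative.
    Returns the probability-weighted mix of iterated states and the step count.
    """
    cum_p = 0.0                    # cumulative halting probability so far
    out = np.zeros_like(h)         # weighted mixture of intermediate states
    steps = 0
    for step in range(max_steps):
        h = h + np.tanh(h @ W)     # the same weights are reused on every iteration
        steps += 1
        p = float(sigmoid(w_halt * h.mean()))  # scalar halting prob from pooled state
        if cum_p + p >= threshold or step == max_steps - 1:
            out += (1.0 - cum_p) * h   # remainder weight: mixture weights sum to 1
            break
        out += p * h
        cum_p += p
    return out, steps


rng = np.random.default_rng(0)
h0 = rng.normal(size=4)
W = 0.1 * rng.normal(size=(4, 4))
# With w_halt = 0 the halting probability is 0.5 at every step,
# so the cumulative probability reaches 0.99 after 2 iterations.
out, steps = looped_block(h0, W, w_halt=0.0)
print(steps)  # -> 2
```

Each iteration reuses `W`, so a block that runs n iterations costs roughly n single-layer passes. This is why looped models are compared against deeper baselines at matched FLOPs rather than at matched parameter counts.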
Abstract
Chain-of-thought (CoT) prompting enables reasoning in language models but requires explicit verbalization of intermediate steps. Looped transformers offer an alternative by iteratively refining representations within hidden states, reusing the same weights at each iteration. This parameter efficiency comes at a cost: looped models lack the storage capacity of deeper models, which use unique weights per layer. In this work, we investigate transformer models that feature both adaptive per-layer looping, where each transformer block learns to iterate its hidden state via a learned halting mechanism, and gated memory banks, which provide additional learned storage. We find that looping primarily benefits mathematical reasoning, while memory banks help recover performance on commonsense tasks compared to parameter- and FLOP-matched models. Combining both mechanisms yields a model that outperforms an iso-FLOP baseline with three times as many layers on math benchmarks. Analysis of model internals reveals layer specialization: early layers learn to loop minimally and access memory sparingly, while later layers do both more heavily.
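The gated memory banks can be sketched in the same hedged way. The sketch below assumes that a memory bank is a set of learned key/value slots. Each token reads from the slots via softmax attention, and a sigmoid gate decides how much of that read is added back to the hidden state. A gate near zero is one plausible way for a layer to "access memory sparingly," as early layers reportedly do. The names `gated_memory_read`, `mem_keys`, `mem_values`, and `w_gate`, and the slot-attention design itself, are assumptions, not the authors' published architecture.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)


def gated_memory_read(h, mem_keys, mem_values, w_gate):
    """Add a gated read from a learned memory bank to hidden states.

    Hypothetical sketch. Shapes: h (T, d), mem_keys/mem_values (M, d), w_gate (d, 1).
    Returns the updated states and the per-token gate values.
    """
    d = h.shape[-1]
    scores = h @ mem_keys.T / np.sqrt(d)        # (T, M) token-to-slot similarity
    read = softmax(scores) @ mem_values         # (T, d) attention-weighted slot contents
    gate = 1.0 / (1.0 + np.exp(-(h @ w_gate)))  # (T, 1) gate in (0, 1); near 0 = sparse use
    return h + gate * read, gate


rng = np.random.default_rng(0)
T, d, M = 3, 8, 16
h = rng.normal(size=(T, d))
keys = rng.normal(size=(M, d))
values = rng.normal(size=(M, d))
w_gate = rng.normal(size=(d, 1))
new_h, gate = gated_memory_read(h, keys, values, w_gate)
print(new_h.shape)  # -> (3, 8)
```

In a full model, a read like this could sit inside each looped block. Under that assumption, a later layer that loops more would also consult memory more often. This is consistent with, but not a reproduction of, the specialization pattern reported above.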