🤖 AI Summary
This work investigates whether lower-complexity models can outperform comparably sized Transformers on mathematical reasoning under a fixed inference budget by exploiting their higher generation throughput. Using a pretrained Transformer as the teacher, the authors distill pure and hybrid Mamba models with a lightweight budget of 8B tokens, fine-tune them on mathematical reasoning data, and scale test-time compute by generating multiple chain-of-thought (CoT) trajectories and aggregating their answers. The distilled Mamba models are substantially faster at inference, especially for long sequences and large batch sizes. They do lose some zero-shot accuracy, but this gap can be offset by drawing more samples within the same time budget. On benchmarks including MATH and AMC, they achieve both higher accuracy and broader problem coverage than their Transformer teachers under fixed time budgets, pointing to a new direction for scaling inference compute.
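To make "distilling a student from a Transformer teacher" concrete, the sketch below shows a generic forward-KL logit distillation objective: the student is trained to match the teacher's next-token distribution at every position. This is an assumption for illustration only; the authors' actual distillation procedure (for example, any layer-wise alignment between attention and Mamba blocks) may differ, and the function names here are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with a temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max to avoid overflow in exp
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits: List[float], student_logits: List[float],
            temperature: float = 1.0) -> float:
    """Forward KL(p_teacher || p_student) at one token position.

    Scaled by T^2 (as in standard logit distillation) so gradient
    magnitudes stay comparable across temperatures.
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2


def sequence_kd_loss(teacher_seq: List[List[float]], student_seq: List[List[float]],
                     temperature: float = 1.0) -> float:
    """Average the per-token KD loss over a sequence of positions."""
    losses = [kd_loss(t, s, temperature) for t, s in zip(teacher_seq, student_seq)]
    return sum(losses) / len(losses)
```

When the student's logits match the teacher's, the loss is zero; any mismatch yields a positive penalty that the student minimizes over the distillation tokens.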
📝 Abstract
Recent advancements have demonstrated that the performance of large language models (LLMs) can be significantly enhanced by scaling computational resources at test time. A common strategy involves generating multiple Chain-of-Thought (CoT) trajectories and aggregating their outputs through various selection mechanisms. This raises a fundamental question: can models with lower computational complexity leverage their superior generation throughput to outperform similarly sized Transformers for a fixed computational budget? To address this question and overcome the lack of strong subquadratic reasoners, we distill pure and hybrid Mamba models from pretrained Transformers. Trained on only 8 billion tokens, our distilled models show strong performance and scaling on mathematical reasoning datasets while being much faster at inference for large batches and long sequences. Despite the drop in zero-shot performance caused by distillation, both pure and hybrid Mamba models can scale their coverage and accuracy beyond those of their Transformer teachers under fixed time budgets, opening a new direction for scaling inference compute.
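The fixed-time-budget argument can be sketched numerically: a faster model fits more CoT samples into the same wall-clock budget, and both coverage and majority-vote accuracy depend on that sample count. The sketch below assumes a simple cost model (a constant number of seconds per sample) and uses the standard unbiased pass@k estimator for coverage together with plain majority voting for aggregation. Whether the authors use exactly these estimators is an assumption, and every timing and correctness count below is hypothetical.

```python
import math
from collections import Counter
from typing import List, Optional


def samples_within_budget(time_budget_s: float, seconds_per_sample: float) -> int:
    """Number of complete CoT samples that fit in a wall-clock budget.

    Hypothetical cost model: each sample takes a constant amount of time.
    """
    return int(time_budget_s // seconds_per_sample)


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased coverage estimate: probability that at least one of k samples,
    drawn without replacement from n generations of which c are correct,
    is correct."""
    if k > n:
        raise ValueError("k cannot exceed the number of generated samples n")
    if n - c < k:
        return 1.0  # every size-k subset contains a correct sample
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


def majority_vote(answers: List[str]) -> Optional[str]:
    """Aggregate final answers by picking the most frequent one."""
    if not answers:
        return None
    return Counter(answers).most_common(1)[0][0]


# Hypothetical comparison under the same 8-second budget:
# a teacher at 2.0 s/sample vs. a distilled student at 0.5 s/sample.
teacher_k = samples_within_budget(8.0, 2.0)
student_k = samples_within_budget(8.0, 0.5)
# Hypothetical per-sample accuracy: teacher 16/64 correct, student 12/64.
teacher_cov = pass_at_k(64, 16, teacher_k)
student_cov = pass_at_k(64, 12, student_k)
```

In this toy setting the student is less accurate per sample, yet its 4x throughput lets it draw 16 samples instead of 4, so its estimated coverage ends up higher, which mirrors the trade-off the abstract describes.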