🤖 AI Summary
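Problem: Prior work has studied how Transformers simulate gradient descent (GD), but it remains unclear whether they can implement, and generalize across, more advanced linear first-order optimization methods such as conjugate gradient descent (CGD) and momentum methods.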
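To make the family of linear first-order methods concrete, here is a minimal NumPy sketch. It assumes a scalar-weight form of the update, `x_{k+1} = x_k - sum_{i<=k} gamma_{k,i} * grad f(x_i)`, on an illustrative least-squares objective. The helper names (`lfom`, `gradient_descent`, `heavy_ball`) and the test problem are hypothetical, not taken from the paper. The sketch checks that GD (weight `alpha` on the newest gradient only) and heavy-ball momentum (weight `alpha * beta**(k - i)` on gradient `i`) are both instances of this one template.

```python
import numpy as np

def lfom(grad, x0, coeffs, num_steps):
    """Generic linear first-order method (illustrative scalar-weight form).

    Update: x_{k+1} = x_k - sum_{i=0}^{k} coeffs(k, i) * grad(x_i).
    """
    x = np.asarray(x0, dtype=float).copy()
    past_grads = []  # gradients evaluated at x_0, ..., x_k
    for k in range(num_steps):
        past_grads.append(grad(x))
        step = sum(coeffs(k, i) * g for i, g in enumerate(past_grads))
        x = x - step
    return x

def gradient_descent(grad, x0, alpha, num_steps):
    # Plain GD, written directly for comparison.
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(num_steps):
        x = x - alpha * grad(x)
    return x

def heavy_ball(grad, x0, alpha, beta, num_steps):
    # Polyak momentum: x_{k+1} = x_k - alpha*grad(x_k) + beta*(x_k - x_{k-1}), with x_{-1} = x_0.
    x_prev = x = np.asarray(x0, dtype=float).copy()
    for _ in range(num_steps):
        x, x_prev = x - alpha * grad(x) + beta * (x - x_prev), x
    return x

# Illustrative least-squares objective f(x) = 0.5 * ||A x - b||^2.
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 5))
b = rng.standard_normal(20)
grad = lambda x: A.T @ (A @ x - b)
x0 = np.zeros(5)
alpha = 1.0 / np.linalg.eigvalsh(A.T @ A).max()  # step size 1/L
beta = 0.5

# GD keeps only the newest gradient; heavy ball weights gradient i by alpha * beta**(k - i).
x_gd_lfom = lfom(grad, x0, lambda k, i: alpha if i == k else 0.0, 50)
x_hb_lfom = lfom(grad, x0, lambda k, i: alpha * beta ** (k - i), 50)
x_gd = gradient_descent(grad, x0, alpha, 50)
x_hb = heavy_ball(grad, x0, alpha, beta, 50)
print(np.allclose(x_gd_lfom, x_gd), np.allclose(x_hb_lfom, x_hb))
```

Conjugate gradient also fits this template on quadratics, because each CG step is a linear combination of past gradients. Its weights, however, are computed from the iterates during the run rather than fixed in advance, so it is omitted here.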
Method: We propose a framework built on memory-augmented Transformers that parameterizes optimization algorithms as learnable objects. We prove theoretically, and verify empirically, that memory-augmented Transformers can implement the entire class of linear first-order methods. We further introduce a mixture-of-experts (MoE) approach to test-time adaptation that improves generalization to out-of-distribution (OOD) samples, as a first step toward turning the learned algorithms into usable methods.
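The abstract does not specify how the MoE test-time adaptation works, so the following is only a hypothetical illustration of the general idea, not the authors' mechanism. It assumes that each "expert" is a heavy-ball momentum setting `(alpha, beta)`, that every expert is run on the test problem, and that a softmax gate over the experts' negative final losses weights their iterates. The names `moe_test_time`, the expert pool, and the temperature are all assumptions.

```python
import numpy as np

def heavy_ball(grad, x0, alpha, beta, num_steps):
    # Polyak heavy-ball momentum; each (alpha, beta) pair acts as one LFOM "expert".
    x_prev = x = np.asarray(x0, dtype=float).copy()
    for _ in range(num_steps):
        x, x_prev = x - alpha * grad(x) + beta * (x - x_prev), x
    return x

def moe_test_time(loss, grad, x0, experts, num_steps, temperature=1.0):
    """Hypothetical test-time mixture over LFOM experts.

    Every expert runs on the test problem; the gate is a softmax over the
    experts' negative final losses, and the output is the gate-weighted
    average of their iterates. Assumes every expert stays finite.
    """
    iterates = [heavy_ball(grad, x0, a, b, num_steps) for a, b in experts]
    losses = np.array([loss(x) for x in iterates])
    logits = -(losses - losses.min()) / temperature  # shift so the largest logit is 0
    weights = np.exp(logits) / np.exp(logits).sum()
    mixed = sum(w * x for w, x in zip(weights, iterates))
    return mixed, weights, losses

# Illustrative test problem: least squares f(x) = 0.5 * ||A x - b||^2.
rng = np.random.default_rng(1)
A = rng.standard_normal((20, 5))
b = rng.standard_normal(20)
loss = lambda x: 0.5 * np.sum((A @ x - b) ** 2)
grad = lambda x: A.T @ (A @ x - b)
L = np.linalg.eigvalsh(A.T @ A).max()
experts = [(0.1 / L, 0.0), (0.5 / L, 0.5), (1.0 / L, 0.9)]  # hypothetical expert pool
mixed, weights, losses = moe_test_time(loss, grad, np.zeros(5), experts, num_steps=10)
print(weights, loss(mixed), losses)
```

Because the least-squares loss is convex and the gate weights are nonnegative and sum to one, the mixed iterate's loss can never exceed the worst expert's loss; the gate simply shifts weight toward whichever expert does best on the sample at hand.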
Contributions/Results: Through a theoretical analysis of simulability and multi-task experiments, we show that the model not only reproduces classical linear first-order methods (LFOMs) but also surpasses them in convergence speed and robustness. We also establish an “algorithm-as-learnable-parameter” paradigm, in which the parameters of an LFOM are learned from data, and validate its cross-task generalization, marking the first unified, learnable framework for linear first-order methods within the Transformer architecture.
📝 Abstract
We show that memory-augmented Transformers can implement the entire class of linear first-order methods (LFOMs), a class that contains gradient descent (GD) and more advanced methods such as conjugate gradient descent (CGD), momentum methods, and all other variants that linearly combine past gradients. Building on prior work that studies how Transformers simulate GD, we provide theoretical and empirical evidence that memory-augmented Transformers can learn more advanced algorithms. We then take a first step toward turning the learned algorithms into practically usable methods by developing a mixture-of-experts (MoE) approach for test-time adaptation to out-of-distribution (OOD) samples. Lastly, we show that LFOMs can themselves be treated as learnable algorithms, whose parameters can be learned from data to attain strong performance.
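The idea that an LFOM's parameters can be learned from data can be illustrated with a deliberately crude sketch. It "learns" the two parameters of heavy-ball momentum (a normalized step size `c`, with `alpha = c / L`, and a momentum weight `beta`) by grid search over a training set of random least-squares problems, then evaluates the chosen setting on held-out problems. The task distribution, grid, step count, and helper names are all assumptions; the abstract does not describe the training procedure, and grid search here stands in for gradient-based training.

```python
import itertools
import numpy as np

def make_problem(rng, n=20, d=5):
    # Random least-squares instance f(x) = 0.5 * ||A x - b||^2 (illustrative task distribution).
    return rng.standard_normal((n, d)), rng.standard_normal(n)

def run_momentum(A, b, c, beta, num_steps):
    # Heavy-ball LFOM with normalized step alpha = c / L, where L = lambda_max(A^T A).
    alpha = c / np.linalg.eigvalsh(A.T @ A).max()
    x_prev = x = np.zeros(A.shape[1])
    for _ in range(num_steps):
        g = A.T @ (A @ x - b)
        x, x_prev = x - alpha * g + beta * (x - x_prev), x
    return x

def suboptimality(A, b, x):
    # f(x) - f(x*), where x* is the exact least-squares solution.
    x_star = np.linalg.lstsq(A, b, rcond=None)[0]
    f = lambda z: 0.5 * np.sum((A @ z - b) ** 2)
    return f(x) - f(x_star)

def average_gap(problems, c, beta, num_steps):
    return np.mean([suboptimality(A, b, run_momentum(A, b, c, beta, num_steps))
                    for A, b in problems])

rng = np.random.default_rng(0)
train = [make_problem(rng) for _ in range(32)]
held_out = [make_problem(rng) for _ in range(32)]
num_steps = 10
step_grid = [0.25, 0.5, 1.0, 1.5]
beta_grid = [0.0, 0.3, 0.6, 0.9]

# "Learning" the LFOM = picking the (c, beta) pair with the lowest average training gap.
best_c, best_beta = min(itertools.product(step_grid, beta_grid),
                        key=lambda p: average_gap(train, p[0], p[1], num_steps))
best_gap = average_gap(train, best_c, best_beta, num_steps)
gd_gap = min(average_gap(train, c, 0.0, num_steps) for c in step_grid)  # best plain GD
held_out_gap = average_gap(held_out, best_c, best_beta, num_steps)
print(best_c, best_beta, best_gap, gd_gap, held_out_gap)
```

Since `beta = 0` (plain GD) is in the grid, the selected setting can only match or beat the best GD step size on the training set; the held-out gap indicates whether the chosen parameters transfer to new problems from the same distribution.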