AI Summary
This work targets the computational cost of the KL-Shampoo optimizer in large language model pre-training, starting from the observation that the eigenvalue spectra of its Kronecker factors follow a "spike-and-flat" shape: a few dominant eigenvalues followed by an approximately uniform tail. Building on this, the authors propose Pro-KLShampoo, which restricts one Kronecker factor to a parametric family that keeps the full spectral structure on a tracked $r$-dimensional subspace and uses a single shared eigenvalue across the remaining $n-r$ directions, where gradient momentum orthogonalization is applied. An identity shows that this orthogonalization recovers the algebraic form of the full KL-Shampoo preconditioner. Across four pre-training scales (GPT-2 124M / 350M, LLaMA 134M / 450M), Pro-KLShampoo outperforms KL-Shampoo at every tested subspace rank in validation loss, peak per-GPU memory, and wall-clock time to reach each loss level.
Abstract
Optimizers that exploit the matrix structure of gradients are central to modern LLM pre-training, with two distinct frontiers: explicit Kronecker-factored preconditioning -- most recently KL-Shampoo, which estimates the preconditioner via KL divergence minimization -- and orthogonalization of the gradient momentum, exemplified by Muon and analyzed as steepest descent under the spectral norm. The two routes are typically developed in isolation. We make a structural observation about KL-Shampoo's Kronecker preconditioners: their eigenvalue spectra exhibit a "spike-and-flat" shape -- a few dominant eigenvalues followed by an approximately uniform tail -- across layers and training stages, and this shape holds exactly under a low-rank signal-plus-noise gradient model. We exploit this structure by restricting one of KL-Shampoo's Kronecker factors to a parametric family aligned with the spike-and-flat shape: full spectral structure on a tracked $r$-dimensional subspace, and a single shared eigenvalue across the remaining $n-r$ directions. On these remaining directions, we apply orthogonalization. An identity shows that this orthogonalization recovers the algebraic form of full KL-Shampoo's preconditioner. On four pre-training scales (GPT-2 124M / 350M, LLaMA 134M / 450M), Pro-KLShampoo consistently outperforms KL-Shampoo at every subspace rank we test in validation loss, peak per-GPU memory, and wall-clock time to reach each loss level.
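The spike-and-flat shape and the parametric family described above can be illustrated with a small NumPy sketch. It is not the authors' implementation, and it rests on several assumptions: the gradient model is taken to be G = U diag(s) V^T + sigma * E with i.i.d. standard Gaussian noise E; a plain second moment E[G^T G] stands in for KL-Shampoo's actual factor estimate; the inverse power `p = 0.5` is an illustrative choice; and the orthogonalization step is left out. All function names are hypothetical.

```python
import numpy as np


def expected_right_factor(V, s, sigma, m):
    """E[G^T G] for the assumed model G = U diag(s) V^T + sigma * E, G of shape (m, n).

    Assumption: U, V have orthonormal columns and E has i.i.d. N(0, 1) entries,
    so E[G^T G] = V diag(s^2) V^T + m * sigma^2 * I.
    """
    n = V.shape[0]
    return V @ np.diag(s**2) @ V.T + m * sigma**2 * np.eye(n)


def apply_structured(G, Q, spikes, flat, p=0.5):
    """Compute G @ R^{-p} for R = Q diag(spikes) Q^T + flat * (I - Q Q^T).

    Only the n x r basis Q, the r spike eigenvalues, and one shared scalar
    are needed; the n x n factor is never formed or decomposed.
    """
    GQ = G @ Q                                # component in the tracked subspace
    in_subspace = (GQ * spikes**(-p)) @ Q.T   # per-direction scaling
    in_complement = (G - GQ @ Q.T) * flat**(-p)  # one shared scaling
    return in_subspace + in_complement


def apply_dense(G, R, p=0.5):
    """Reference: G @ R^{-p} via a full n x n eigendecomposition."""
    w, E = np.linalg.eigh(R)
    return ((G @ E) * w**(-p)) @ E.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m, n, k = 8, 12, 3
    V, _ = np.linalg.qr(rng.standard_normal((n, k)))  # orthonormal signal basis
    s, sigma = np.array([5.0, 3.0, 2.0]), 0.1

    R = expected_right_factor(V, s, sigma, m)
    print("eigenvalues:", np.round(np.linalg.eigvalsh(R), 4))

    G = rng.standard_normal((m, n))
    fast = apply_structured(G, V, s**2 + m * sigma**2, m * sigma**2)
    print("matches dense:", np.allclose(fast, apply_dense(G, R)))
```

Under these assumptions the factor has k eigenvalues equal to s_i^2 + m * sigma^2 and n - k eigenvalues equal to m * sigma^2, so the tail is exactly flat. The structured routine then reproduces the dense result while storing only an n x r basis, r eigenvalues, and one scalar.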