🤖 AI Summary
This work addresses the heavy computational overhead that makes the Shampoo optimizer expensive to apply to large models. The authors propose DASH, an efficient distributed implementation that stacks preconditioner blocks into 3D tensors to improve GPU utilization. They also speed up the computation of the inverse matrix roots that Shampoo requires with two new approaches: the Newton-DB iteration and Chebyshev polynomial approximations. In addition, they provide the first in-depth analysis of how matrix scaling affects Shampoo's convergence. Experiments show up to a 4.83× speedup in optimizer step time over Distributed Shampoo, and the Newton-DB variant attains the lowest validation perplexity per iteration among all tested methods.
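The summary refers to Chebyshev polynomial approximation as a way to compute inverse matrix roots. The sketch below shows one common way to realize that idea, and it is an assumption rather than the paper's construction. It interpolates x^{-1/2} on an interval [lam_min, lam_max] that is assumed to contain the spectrum, then evaluates the polynomial on the matrix with the Clenshaw recurrence, so only matrix products are needed. The function name `cheb_inv_sqrt`, its parameters, and the default degree are hypothetical. The eigenvalue bounds are taken as known here, whereas a real implementation would have to estimate them, which is where matrix scaling enters.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def cheb_inv_sqrt(A, lam_min, lam_max, degree=40):
    """Approximate A^{-1/2} for an SPD matrix A with eigenvalues in [lam_min, lam_max].

    Hypothetical helper: a Chebyshev interpolant of x -> x^{-1/2} is evaluated
    on A with the Clenshaw recurrence, using only matrix products (no inverses).
    """
    n = A.shape[-1]
    eye = np.eye(n)
    center = 0.5 * (lam_max + lam_min)
    half_width = 0.5 * (lam_max - lam_min)
    # Coefficients c_0..c_degree of t -> (center + half_width * t)^{-1/2} on [-1, 1].
    coeffs = C.chebinterpolate(lambda t: (center + half_width * t) ** -0.5, degree)
    # Affine map sending the assumed spectrum [lam_min, lam_max] onto [-1, 1].
    T = (A - center * eye) / half_width
    b1 = np.zeros_like(A)  # holds b_{k+1}
    b2 = np.zeros_like(A)  # holds b_{k+2}
    for c in coeffs[:0:-1]:  # iterate c_degree, ..., c_1
        b1, b2 = c * eye + 2.0 * (T @ b1) - b2, b1
    # Final Clenshaw step: f(T) = c_0 I + T b_1 - b_2.
    return coeffs[0] * eye + T @ b1 - b2
```

Because the cost is one matrix product per degree, the degree trades accuracy for speed. The approximation degrades quickly if the true spectrum extends below lam_min.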
📝 Abstract
Shampoo is one of the leading approximate second-order optimizers: a variant of it won the MLCommons AlgoPerf competition, and it has been shown to produce models with smaller activation outliers that are easier to compress. Yet applying Shampoo currently comes at the cost of a significant computational slowdown, due to its expensive internal operations. In this paper, we take a significant step toward addressing this shortcoming by proposing DASH (for Distributed Accelerated SHampoo), a faster implementation of Distributed Shampoo based on two main new techniques. First, we show that preconditioner blocks can be stacked into 3D tensors to significantly improve GPU utilization. Second, we introduce the Newton-DB iteration and Chebyshev polynomial approximations as novel and faster approaches for computing the inverse matrix roots required by Shampoo. Along with these algorithmic contributions, we provide a first in-depth analysis of how matrix scaling critically affects Shampoo convergence. On the practical side, our GPU-aware implementation achieves up to 4.83× faster optimizer steps than the well-optimized Distributed Shampoo, while Newton-DB attains the lowest validation perplexity per iteration among all tested methods. Our code is available at https://github.com/IST-DASLab/DASH.
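The abstract names the Newton-DB iteration but does not define it. This sketch assumes that "DB" refers to the classical Denman–Beavers iteration, a coupled Newton iteration whose iterates converge to A^{1/2} and A^{-1/2}. The paper's actual variant may differ.

The sketch computes the inverse square root, which is one instance of the inverse matrix roots Shampoo needs. It also illustrates the other two ideas in simplified form:

* Same-sized blocks are stacked into a single (b, n, n) array, so every inverse and product becomes one batched call instead of a Python loop.
* Each block is first rescaled by its Frobenius norm.

The helpers `stack_blocks` and `batched_inv_sqrt_db` are hypothetical. NumPy stands in for the GPU tensors a real implementation would use.

```python
import numpy as np


def stack_blocks(blocks):
    """Stack same-sized square preconditioner blocks (n, n) into one (b, n, n) array."""
    return np.stack(blocks, axis=0)


def batched_inv_sqrt_db(A, num_iters=20, eps=0.0):
    """Inverse square root of every SPD block in a (b, n, n) stack.

    Assumed technique: textbook Denman-Beavers iteration on Frobenius-norm-scaled
    blocks. np.linalg.inv and @ broadcast over the leading axis, so all blocks
    are processed together.
    """
    b, n, _ = A.shape
    eye = np.broadcast_to(np.eye(n), (b, n, n))
    A = A + eps * eye  # optional damping eps * I (hypothetical default of 0)
    # Per-block scale c = ||A||_F, which puts every eigenvalue of A / c in (0, 1].
    c = np.linalg.norm(A, ord="fro", axis=(1, 2))[:, None, None]
    Y = A / c       # converges to (A / c)^{1/2}
    Z = eye.copy()  # converges to (A / c)^{-1/2}
    for _ in range(num_iters):
        # Simultaneous update: both right-hand sides use the previous Y and Z.
        Y, Z = 0.5 * (Y + np.linalg.inv(Z)), 0.5 * (Z + np.linalg.inv(Y))
    # Undo the scaling: (A / c)^{-1/2} = sqrt(c) * A^{-1/2}.
    return Z / np.sqrt(c)
```

The sketch assumes all blocks share one size. Blocks of different sizes would first have to be grouped by size or padded before stacking. Frobenius-norm scaling is only one simple choice. The abstract reports that the choice of matrix scaling critically affects convergence, so the scaling the paper recommends may differ.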