🤖 AI Summary
To address slow convergence and instability in low-rank gradient optimization for LLM training, caused by the mismatch between highly anisotropic gradient geometry and isotropic steepest-descent updates, this paper proposes SUMO (Subspace-Aware Moment-Orthogonalization). SUMO performs exact SVD-based moment orthogonalization within a dynamically estimated low-dimensional gradient subspace, enabling norm-inducing steepest-descent steps. It introduces the first moment-orthogonalization mechanism built on exact SVD, aligning updates with the spectral structure of the loss landscape to improve robustness to ill-conditioning. The authors derive a quantitative relationship between the Newton–Schulz approximation error and the condition number of the moment matrix, establishing a rigorous error bound. Experiments show that SUMO accelerates convergence by 23%–37% over state-of-the-art low-rank optimizers, improves training stability, raises model accuracy by 0.8–1.5 percentage points on average, and reduces memory footprint by up to 20%.
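To make the mechanism concrete, the sketch below shows one possible SUMO-style update step in PyTorch: the gradient is projected into a periodically re-estimated low-rank subspace, the moment is accumulated there, orthogonalized by an exact SVD, and mapped back for the weight update. This is a minimal reconstruction from the summary, not the authors' implementation; the projection scheme and the hyperparameter names (`rank`, `update_interval`, `beta`) are assumptions.

```python
# Minimal sketch of a SUMO-style update, reconstructed from the summary above
# (not the authors' code). Assumptions: a GaLore-style projection refreshed every
# `update_interval` steps from the gradient's top singular vectors, momentum kept
# inside the low-rank subspace, and exact-SVD orthogonalization of the moment.
import torch

def sumo_step(weight, grad, state, lr=1e-3, beta=0.9, rank=64, update_interval=200):
    step = state.setdefault("step", 0)

    # Periodically re-estimate the low-dimensional gradient subspace (m x rank).
    # Handling of the stored moment when the basis changes is omitted for brevity.
    if "P" not in state or step % update_interval == 0:
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["P"] = U[:, :rank]
    P = state["P"]

    # Project the gradient into the subspace (rank x n) and accumulate the moment there.
    g_low = P.T @ grad
    m = state.get("m", torch.zeros_like(g_low))
    m = beta * m + (1 - beta) * g_low
    state["m"] = m

    # Exact orthogonalization via SVD: replace the moment with its polar factor U V^T,
    # the nearest (semi-)orthogonal matrix, instead of a Newton-Schulz approximation.
    U_m, _, Vh_m = torch.linalg.svd(m, full_matrices=False)
    m_orth = U_m @ Vh_m

    # Map the orthogonalized moment back to the full space and take the step.
    weight.add_(P @ m_orth, alpha=-lr)
    state["step"] = step + 1
```

Called once per parameter matrix with a persistent `state` dict, this mirrors how GaLore-style optimizers keep per-parameter projections; the exact SVD is applied only to the small rank-by-n moment, which is what keeps in-subspace orthogonalization cheap relative to a full-matrix SVD.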
📝 Abstract
Low-rank gradient-based optimization methods have significantly improved memory efficiency during the training of large language models (LLMs), enabling training on constrained hardware without sacrificing performance. However, these methods primarily emphasize memory savings and often overlook potential gains in convergence speed, because they rely on standard isotropic steepest descent, which can perform suboptimally in the highly anisotropic loss landscapes typical of deep networks, particularly LLMs. In this paper, we propose SUMO (Subspace-Aware Moment-Orthogonalization), an optimizer that employs exact singular value decomposition (SVD) for moment orthogonalization within a dynamically adapted low-dimensional subspace, enabling norm-inducing steepest-descent optimization steps. By explicitly aligning optimization steps with the spectral characteristics of the loss landscape, SUMO mitigates the approximation errors associated with commonly used methods such as Newton-Schulz orthogonalization. We theoretically establish an upper bound on these approximation errors, proving their dependence on the condition numbers of the moment matrices, conditions we analytically show arise during LLM training. Furthermore, we demonstrate both theoretically and empirically that exact orthogonalization via SVD substantially improves convergence rates while reducing overall complexity. Empirical evaluations confirm that SUMO accelerates convergence, enhances stability, improves performance, and reduces memory requirements by up to 20% compared to state-of-the-art methods.
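To illustrate the contrast between exact and approximate orthogonalization, the sketch below compares SVD-based orthogonalization with a truncated Newton-Schulz iteration of the kind used by Muon-style optimizers. It is illustrative only: the Newton-Schulz coefficients and the synthetic ill-conditioned moment matrix are not from the paper, but the experiment shows the qualitative effect the error bound describes, namely that the approximation degrades as the condition number grows.

```python
# Illustrative comparison (not from the paper): exact SVD orthogonalization versus
# a fixed-step Newton-Schulz iteration. The approximation error of the truncated
# iteration grows with the condition number of the input, which is the regime the
# paper's bound characterizes and SUMO's exact SVD avoids.
import torch

def orthogonalize_svd(M):
    # Exact polar-factor orthogonalization: nearest (semi-)orthogonal matrix U V^T.
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    return U @ Vh

def orthogonalize_newton_schulz(M, steps=5):
    # Quintic Newton-Schulz iteration (as used in Muon-style optimizers), run for a
    # fixed, small number of steps on a normalized input. Coefficients are the
    # commonly used values, included here only for illustration.
    X = M / (M.norm() + 1e-7)
    a, b, c = 3.4445, -4.7750, 2.0315
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X

# Synthetic ill-conditioned moment matrix (condition number ~1e4): the truncated
# iteration leaves small singular directions far from 1, while the SVD result is
# exact up to floating-point precision.
torch.manual_seed(0)
Q1, _ = torch.linalg.qr(torch.randn(64, 64))
Q2, _ = torch.linalg.qr(torch.randn(64, 64))
spectrum = torch.logspace(0, -4, 64)
M = Q1 @ torch.diag(spectrum) @ Q2.T
exact = orthogonalize_svd(M)
approx = orthogonalize_newton_schulz(M)
print((approx - exact).norm() / exact.norm())   # noticeably > 0 for ill-conditioned M
```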