🤖 AI Summary
To address the convergence difficulties of conventional momentum optimizers in low-rank neural network training—stemming from their neglect of the underlying parameter manifold’s geometric structure—this paper proposes the first geometry-aware momentum optimization framework. It explicitly incorporates differential-geometric constraints from dynamical low-rank approximation into the momentum update mechanism, leveraging low-rank matrix manifold modeling, projected gradient flow design, explicit–implicit time integration, and a corrected momentum rule to rigorously confine optimization trajectories to the manifold. This formulation overcomes the divergence and poor convergence behavior of classical momentum methods under low-rank constraints. Experiments demonstrate that, under identical parameter budgets, the method achieves a 37% faster convergence rate and improves average downstream task accuracy by 2.1 percentage points, significantly outperforming both Adam and SGD with momentum.
📝 Abstract
Low-rank pre-training and fine-tuning have recently emerged as promising techniques for reducing the computational and storage costs of large neural networks. Training low-rank parameterizations typically relies on conventional optimizers such as heavy ball momentum methods or Adam. In this work, we identify and analyze potential difficulties that these training methods encounter when used to train low-rank parameterizations of weights. In particular, we show that classical momentum methods can struggle to converge to a local optimum due to the geometry of the underlying optimization landscape. To address this, we introduce novel training strategies derived from dynamical low-rank approximation, which explicitly account for the underlying geometric structure. Our approach leverages and combines tools from dynamical low-rank approximation and momentum-based optimization to design optimizers that respect the intrinsic geometry of the parameter space. We validate our methods through numerical experiments, demonstrating faster convergence, and stronger validation metrics at given parameter budgets.