🤖 AI Summary
This work addresses the memory bottleneck and gradient-noise challenges that high-dimensional parameter spaces pose for large language model training. The authors propose an unbiased low-rank gradient estimation method based on random projections onto low-dimensional subspaces. By deriving an optimal random-projection distribution, specifically the Haar–Stiefel distribution, through constrained functional optimization, they obtain a gradient estimator that minimizes variance. The approach substantially reduces the memory footprint of backpropagation while preserving training accuracy. Experiments show significant memory savings: peak GPU memory during RoBERTa-large fine-tuning drops from 16.7 GB to 3.83 GB, and the method outperforms conventional approaches in pretraining LLaMA-family models, enabling efficient and scalable large-model training.
📝 Abstract
Large language model (LLM) training is often bottlenecked by memory constraints and stochastic gradient noise in extremely high-dimensional parameter spaces. Motivated by empirical evidence that many LLM gradient matrices are effectively low-rank during training, we present an unbiased, memory-efficient, minimum-variance low-rank matrix estimator that is applicable across common stochastic gradient estimation paradigms. The core idea is to project a high-dimensional stochastic gradient estimator onto a random low-dimensional subspace and lift it back, reducing memory while keeping the estimator unbiased and controlling mean-squared error through an optimally designed projection distribution, including Haar–Stiefel projections. The projection distribution is derived by solving a constrained functional optimization problem, yielding an optimal random projector that guides algorithm design. Empirically, the resulting low-rank gradient estimators deliver both practical memory savings and improved training behavior. In RoBERTa-large fine-tuning, our method attains the lowest peak GPU memory among compared methods (e.g., 3.83 GB versus 16.7 GB for full backpropagation) while remaining competitive in accuracy; in autoregressive LLM pretraining (LLaMA-20M/60M/100M), our method outperforms traditional baselines, supporting the benefit of the proposed optimal projection strategy.
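The project-and-lift idea in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration under simplifying assumptions (a vector-valued gradient, toy dimensions, and function names of our choosing, not the paper's): draw a Haar-distributed orthonormal frame P on the Stiefel manifold, store only the k-dimensional sketch P^T g, and lift back with the n/k rescaling that makes the estimator unbiased, since E[P P^T] = (k/n) I for Haar–Stiefel P.

```python
import numpy as np

def haar_stiefel_projector(n, k, rng):
    """Draw an n x k matrix with orthonormal columns, Haar-distributed
    on the Stiefel manifold, via QR of a standard Gaussian matrix."""
    A = rng.standard_normal((n, k))
    Q, R = np.linalg.qr(A)
    # Sign-correct the columns so the distribution is exactly Haar.
    return Q * np.sign(np.diag(R))

def low_rank_gradient_estimate(g, k, rng):
    """Project g onto a random k-dimensional subspace and lift it back.
    The n/k rescaling makes the estimator unbiased, because
    E[P P^T] = (k/n) I when P is Haar-Stiefel distributed."""
    n = g.shape[0]
    P = haar_stiefel_projector(n, k, rng)
    sketch = P.T @ g          # only this k-vector (plus P) needs storing
    return (n / k) * (P @ sketch)

# Monte Carlo check of unbiasedness on toy sizes (illustrative only).
rng = np.random.default_rng(0)
n, k = 8, 2
g = rng.standard_normal(n)
avg = np.mean([low_rank_gradient_estimate(g, k, rng)
               for _ in range(200_000)], axis=0)
print(np.max(np.abs(avg - g)))  # small: the estimates average to g
```

The memory saving in the paper comes from applying this kind of low-rank sketch to gradient matrices during backpropagation; the sketch above only demonstrates the unbiasedness mechanism, not the full training algorithm.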