🤖 AI Summary
To reduce the memory consumed by optimizer states when training large models, this paper proposes a correlation-aware low-rank gradient projection method. Unlike existing low-rank approaches, which ignore the statistical correlation between gradients and momentum, the method is the first to explicitly model the cross-step gradient–momentum covariance structure and to jointly estimate their shared low-rank subspace. It uses a lightweight, SVD-free projection update combined with dynamic rank adaptation and 8-bit quantization. On LLaMA-1B, it reduces optimizer memory by 61% while matching AdamW's perplexity. When fine-tuning LLaVA-v1.5-7B with 8-bit quantization, it reduces optimizer memory by 81%, trains 4× faster than GaLore, and achieves higher accuracy. Key contributions include: (1) explicit modeling of cross-step gradient–momentum correlations; (2) covariance-driven joint low-rank projection; and (3) an efficient, SVD-free update framework.
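The summary's two mechanisms, a subspace shared by gradient and momentum and a projection update that avoids a full SVD, can be illustrated with a generic technique: a few steps of subspace (orthogonal) iteration, warm-started from the previous projector and applied to the gradient concatenated with the momentum mapped back to full size. This is a hypothetical sketch for intuition only. `refresh_projector` and its arguments are invented names, and the sketch is not COAP's actual update rule, which the summary does not specify.

```python
import numpy as np


def refresh_projector(proj: np.ndarray, grad: np.ndarray,
                      momentum_low: np.ndarray, iters: int = 1) -> np.ndarray:
    """Hypothetical SVD-free refresh of an (m, r) orthonormal projector.

    Target: the dominant left singular subspace of [grad | proj @ momentum_low],
    so the new gradient and the momentum kept in the old subspace both shape
    the result. Each iteration costs O(m * n * r) plus a QR of an (m, r)
    matrix, instead of a full SVD of the (m, n) gradient.
    """
    joint = np.hstack([grad, proj @ momentum_low])  # (m, 2n)
    q = proj
    for _ in range(iters):
        # Subspace (orthogonal) iteration, warm-started from the old projector.
        q, _ = np.linalg.qr(joint @ (joint.T @ q))
    return q


# Toy check: a gradient with a strong rank-r component plus small noise.
rng = np.random.default_rng(0)
m, n, r = 128, 64, 4
basis, _ = np.linalg.qr(rng.standard_normal((m, r)))  # true dominant subspace
grad = 10.0 * (basis @ rng.standard_normal((r, n))) + 1e-3 * rng.standard_normal((m, n))
old_proj, _ = np.linalg.qr(rng.standard_normal((m, r)))  # arbitrary starting projector
momentum_low = np.zeros((r, n))  # no momentum yet on the first refresh
new_proj = refresh_projector(old_proj, grad, momentum_low, iters=3)
# Distance between the recovered subspace and the true one.
print(np.linalg.norm(new_proj @ new_proj.T - basis @ basis.T))
```

Warm-starting from the old projector is what typically keeps the iteration count low: when the dominant subspace drifts slowly between refreshes, the starting point is already close, so one or two iterations suffice.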
📝 Abstract
Training large-scale neural networks in vision, language, and multimodal domains demands substantial memory, primarily to store optimizer states. LoRA, a popular parameter-efficient method, reduces memory usage but often underperforms because its updates are constrained to be low-rank. Low-rank gradient projection methods (e.g., GaLore, Flora) reduce optimizer memory by projecting gradients and moment estimates into low-rank spaces via singular value decomposition or random projection. However, they ignore inter-projection correlation, which degrades performance, and their projection strategies often incur high computational costs. In this paper, we present COAP (Correlation-Aware Gradient Projection), a memory-efficient method that minimizes computational overhead while maintaining training performance. Evaluated across various vision, language, and multimodal tasks, COAP outperforms existing methods in both training speed and model performance. For LLaMA-1B, it reduces optimizer memory by 61% with only 2% additional training time, while achieving the same perplexity (PPL) as AdamW. With 8-bit quantization, COAP cuts optimizer memory by 81% and trains 4× faster than GaLore when fine-tuning LLaVA-v1.5-7B, while delivering higher accuracy.
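To make the memory argument concrete, here is a minimal NumPy sketch of the baseline family the abstract contrasts against: Adam whose moments are stored in a rank-r projected space, with the projector taken either from an SVD of the gradient (GaLore-like) or from a scaled Gaussian random matrix (Flora-like). The function names, the single left-side projection, and the rank values are illustrative assumptions, not either method's actual implementation (GaLore, for instance, refreshes its projector only periodically and chooses the projection side by matrix shape). COAP's correlation-aware update is not modeled here.

```python
from typing import Optional

import numpy as np


def svd_projector(grad: np.ndarray, rank: int) -> np.ndarray:
    """SVD-based projector (GaLore-like): top-`rank` left singular vectors."""
    u, _, _ = np.linalg.svd(grad, full_matrices=False)
    return u[:, :rank]  # (m, r) with orthonormal columns


def random_projector(m: int, rank: int, rng: np.random.Generator) -> np.ndarray:
    """Random projector (Flora-like): Gaussian matrix scaled so E[P @ P.T] = I."""
    return rng.standard_normal((m, rank)) / np.sqrt(rank)


def projected_adam_step(weight, grad, proj, state,
                        lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step whose moments are stored in the (r, n) projected space.

    Assumes a left-side projection of an (m, n) weight; `state` holds the
    step count "t" and the low-rank moments "m" and "v", updated in place.
    """
    low_grad = proj.T @ grad  # compress the gradient to (r, n)
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * low_grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * low_grad**2
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    low_update = m_hat / (np.sqrt(v_hat) + eps)
    return weight - lr * (proj @ low_update)  # map the update back to (m, n)


def optimizer_floats(m: int, n: int, rank: Optional[int] = None) -> int:
    """Optimizer-state floats for one (m, n) weight: full Adam or rank-r projected."""
    if rank is None:
        return 2 * m * n  # full-size first and second moments
    return 2 * rank * n + m * rank  # low-rank moments plus the projector itself


rng = np.random.default_rng(0)
m, n, r = 64, 32, 8
w = rng.standard_normal((m, n))
g = rng.standard_normal((m, n))
P = svd_projector(g, r)
state = {"t": 0, "m": np.zeros((r, n)), "v": np.zeros((r, n))}
w_new = projected_adam_step(w, g, P, state)
print(optimizer_floats(4096, 4096), optimizer_floats(4096, 4096, rank=256))
# prints: 33554432 3145728
```

For a single 4096 × 4096 weight at rank 256, the state drops from 33,554,432 to 3,145,728 floats. This per-matrix count is not comparable to the end-to-end 61% and 81% figures above, which depend on which layers are projected, the chosen rank, and (for 81%) 8-bit storage.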