🤖 AI Summary
To address the prohibitively high memory overhead of momentum-based optimizers in large language model training, this paper proposes AdaPM—an adaptive partial momentum optimization method leveraging non-uniform momentum allocation and bias correction. Its core idea is to retain momentum states only for critical parameter blocks, coupled with gradient-statistics-driven sparse momentum updates and exact bias correction, thereby drastically reducing optimizer state storage. Experiments across the full training pipeline—including pretraining, fine-tuning, and RLHF—on language models ranging from 60M to 1.5B parameters demonstrate that AdaPM reduces momentum memory usage by over 90% and compresses the overall optimizer state by up to 95%. For GPT-2 (1.5B), it saves more than 30% of GPU-hours during pretraining, without compromising convergence speed or final model performance. Moreover, AdaPM composes with compression techniques for the second-order statistic, offering a memory-efficient, scalable approach to optimization in large-scale training.
📝 Abstract
In the training of large language models, momentum is widely used and often shown to yield significant acceleration. However, storing momentum typically incurs a substantial memory cost. In this paper, we propose AdaPM, an adaptive training strategy that leverages partial momentum to implement a memory-efficient optimizer. To this end, AdaPM adopts a non-uniform momentum design: for most parameter blocks, full momentum is not necessary to preserve optimization performance. To mitigate the bias and performance loss caused by partial momentum, AdaPM enhances the partial momentum with a bias correction technique. Empirically, we verify that our approach reduces momentum memory by over $90\%$ while maintaining both efficiency and performance for pretraining language models ranging from 60M to 1.5B parameters, as well as for supervised fine-tuning and RLHF. By combining AdaPM with a memory-efficient technique for the second-order statistic, memory for optimizer states can be further reduced by up to $95\%$, saving over $30\%$ of GPU hours when pretraining GPT-2 1.5B.
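To make the idea of partial momentum with bias correction concrete, here is a minimal NumPy sketch of one optimizer step. The block-selection flag (`use_momentum`), the hyperparameters, and the Adam-style bias correction `m / (1 - beta**t)` are illustrative assumptions, not the paper's exact selection rule or update:

```python
import numpy as np

def adapm_step(params, grads, state, lr=1e-3, beta=0.9, t=1):
    """One step of a partial-momentum optimizer (sketch, not the paper's exact method).

    Only parameter blocks flagged in state["use_momentum"] keep a momentum
    buffer; all other blocks are updated with the raw gradient, so no
    optimizer state is stored for them. Bias correction follows the
    Adam convention (an assumption).
    """
    for name, g in grads.items():
        if state["use_momentum"].get(name, False):
            # EMA momentum, stored only for this block
            m = state["m"].setdefault(name, np.zeros_like(g))
            m[:] = beta * m + (1 - beta) * g
            m_hat = m / (1 - beta ** t)   # bias-corrected momentum
            params[name] -= lr * m_hat
        else:
            # Momentum-free block: plain gradient step, no state kept
            params[name] -= lr * g
    return params, state
```

Memory savings come from the `else` branch: blocks without the flag never allocate a momentum buffer, so if most blocks are flagged off, the momentum state shrinks accordingly.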