🤖 AI Summary
This work investigates why stochastic gradient descent (SGD) underperforms adaptive optimizers such as Adam in large language model (LLM) pretraining. It attributes much of the gap to SGD's inability to sustain a sufficiently large effective learning rate. Through theoretical analysis and large-scale experiments, the study finds that pretraining, especially with large batches, is characterized by small gradient norms and large weight-to-gradient ratios, which call for large effective learning rates. At the same time, output-layer gradient magnitudes are highly uneven across token classes and large gradient spikes occur frequently, and together these effects sharply limit the learning rates SGD can tolerate. To address this, the authors propose simple clipping mechanisms that stabilize SGD at large learning rates. When pretraining a 1B-parameter LLaMA model with 1M-token batches, this approach reduces the validation loss gap between large-learning-rate SGD and Adam from over 50% to approximately 3.5%.
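A short worked relation helps show why small gradient norms push SGD toward large learning rates. It assumes "weight-to-gradient ratio" means the ratio of the parameter norm to the gradient norm; that reading is ours, not a definition given in the text. For a plain SGD step, the relative size of the update is the learning rate times the inverse of that ratio. Reaching a target relative update $r$ therefore requires a learning rate proportional to the ratio:

```latex
\Delta w = -\eta\, g
\quad\Longrightarrow\quad
\frac{\lVert \Delta w \rVert}{\lVert w \rVert} = \eta\,\frac{\lVert g \rVert}{\lVert w \rVert}
\quad\Longrightarrow\quad
\eta = r\,\frac{\lVert w \rVert}{\lVert g \rVert}
```

With illustrative (hypothetical) values $\lVert w \rVert / \lVert g \rVert = 10^{4}$ and $r = 10^{-3}$, this gives $\eta = 10$. This is far above typical SGD learning rates, so a single gradient spike multiplied by such an $\eta$ can produce a destabilizing step.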
📝 Abstract
It is widely believed that stochastic gradient descent (SGD) performs significantly worse than adaptive optimizers such as Adam in pre-training Large Language Models (LLMs). Yet the underlying reason for this gap remains unclear. In this work, we attribute a large part of the discrepancy to SGD's inability to sustain learning rates comparable to the much larger effective learning rates of Adam. Through empirical and theoretical analysis of LLM pre-training dynamics, we find that training is characterized by small gradient norms and large weight-to-gradient ratios, which necessitate such large effective learning rates. This effect becomes more pronounced at the large batch sizes typical of pre-training. However, output-layer gradient magnitudes become highly uneven across token classes, and large gradient spikes occur frequently during training. Together, these effects severely restrict the admissible learning rate of SGD. Guided by this understanding, we show that simple clipping mechanisms that stabilize SGD at large learning rates enable it to recover most of Adam's performance. In our large-scale experiments, the validation loss gap between large-learning-rate SGD and Adam shrinks from more than 50% to only about 3.5% when pre-training a 1B-parameter LLaMA model with a 1M-token batch size.
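The abstract does not specify which clipping mechanism is used. The sketch below therefore substitutes standard global-norm gradient clipping, a well-known technique that may differ from the authors' method. It illustrates how clipping bounds the step size of SGD at a large learning rate. The helper names (`global_norm`, `clip_by_global_norm`, `weight_to_grad_ratio`, `sgd_step`) and the `max_norm` threshold are hypothetical. Parameters are plain Python lists so that the sketch needs no third-party libraries.

```python
import math
from typing import List, Tuple

Vector = List[float]


def global_norm(tensors: List[Vector]) -> float:
    """Joint L2 norm over all tensors, treated as one flattened vector."""
    return math.sqrt(sum(x * x for t in tensors for x in t))


def clip_by_global_norm(grads: List[Vector], max_norm: float) -> Tuple[List[Vector], float]:
    """Rescale all gradients by one common factor so their joint norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    norm = global_norm(grads)
    if norm <= max_norm or norm == 0.0:
        return [list(g) for g in grads], norm
    scale = max_norm / norm
    return [[x * scale for x in g] for g in grads], norm


def weight_to_grad_ratio(w: Vector, g: Vector, eps: float = 1e-12) -> float:
    """Hypothetical diagnostic: ||w|| / ||g|| for a single parameter tensor."""
    return global_norm([w]) / (global_norm([g]) + eps)


def sgd_step(params: List[Vector], grads: List[Vector], lr: float, max_norm: float) -> List[Vector]:
    """One plain SGD update (no momentum) applied to globally clipped gradients."""
    clipped, _ = clip_by_global_norm(grads, max_norm)
    return [[p - lr * g for p, g in zip(pt, gt)] for pt, gt in zip(params, clipped)]


# Usage: a gradient "spike" of norm 100 is scaled down to norm 1 before the step.
params = [[1.0, 1.0, 1.0, 1.0]]
spike = [[100.0, 0.0, 0.0, 0.0]]
print(sgd_step(params, spike, lr=1.0, max_norm=1.0))
```

Because the clipped gradient has norm at most `max_norm`, each update has norm at most `lr * max_norm`. A large learning rate therefore cannot turn an isolated spike into an arbitrarily large step, while ordinary small gradients pass through unchanged.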