Straight to Zero: Why Linearly Decaying the Learning Rate to Zero Works Best for LLMs

📅 2025-02-21
📈 Citations: 1
Influential: 0
🤖 AI Summary
This work studies learning rate (LR) decay schedules for large language model (LLM) training. In a large-scale comparison, linear decay to zero (D2Z) at an optimal peak LR consistently outperforms other schedules, including the common cosine decay to 10% of the peak (10x decay), when training at compute-optimal tokens-per-parameter (TPP). The authors interpret AdamW as an exponential moving average of weight updates and use this view to argue that linear D2Z balances two demands: moving away from the initial parameters early in training, and averaging over more updates late in training to suppress gradient noise. Empirically, a 610M-parameter model trained with D2Z for 80 TPP reaches lower loss than the same model trained with 10x decay for 200 TPP, a 60% compute saving. The abstract also suggests that models such as Llama2-7B, trained for 286 TPP with 10x decay, could likely have saved a majority of that compute with D2Z.

📝 Abstract
LLMs are commonly trained with a learning rate (LR) warmup, followed by cosine decay to 10% of the maximum (10x decay). In a large-scale empirical study, we show that under an optimal peak LR, a simple linear decay-to-zero (D2Z) schedule consistently outperforms other schedules when training at compute-optimal dataset sizes. D2Z is superior across a range of model sizes, batch sizes, datasets, and vocabularies. Benefits increase as dataset size increases. Leveraging a novel interpretation of AdamW as an exponential moving average of weight updates, we show how linear D2Z optimally balances the demands of early training (moving away from initial conditions) and late training (averaging over more updates in order to mitigate gradient noise). In experiments, a 610M-parameter model trained for 80 tokens-per-parameter (TPP) using D2Z achieves lower loss than when trained for 200 TPP using 10x decay, corresponding to an astonishing 60% compute savings. Models such as Llama2-7B, trained for 286 TPP with 10x decay, could likely have saved a majority of compute by training with D2Z.
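
To make the two schedules in the abstract concrete, here is a minimal sketch of warmup followed by either linear decay-to-zero or cosine decay to 10% of the peak. The function name `lr_at_step`, its parameters, and the linear warmup shape are illustrative assumptions, not details taken from the paper.

```python
import math


def lr_at_step(step, total_steps, peak_lr, warmup_steps, schedule="linear_d2z"):
    """Return the learning rate at a given optimizer step.

    Hypothetical helper: linear warmup to peak_lr, then one of two decays.
    """
    if total_steps <= warmup_steps:
        raise ValueError("total_steps must exceed warmup_steps")

    # Linear warmup (assumed shape): reaches peak_lr on the last warmup step.
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps

    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)

    if schedule == "linear_d2z":
        # Linear decay from peak_lr to exactly zero.
        return peak_lr * (1.0 - progress)
    if schedule == "cosine_10x":
        # Cosine decay from peak_lr to 10% of peak_lr ("10x decay").
        min_lr = 0.1 * peak_lr
        return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown schedule: {schedule}")
```

The only difference between the two branches is the end point and the shape of the decay: the linear branch ends at zero, while the cosine branch ends at one tenth of the peak.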
Problem

Research questions and friction points this paper is trying to address.

Optimizing learning rate schedules for LLMs
Whether linear decay-to-zero outperforms cosine decay
Reducing computational costs in model training
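
The 60% figure quoted in the abstract can be checked with a short calculation. The sketch below assumes the common approximation that training compute is about 6 × parameters × tokens; that formula and the helper name `training_flops` are assumptions made here for illustration, and the savings ratio does not depend on the constant.

```python
def training_flops(n_params, tokens_per_param):
    """Rough training compute, assuming C ~ 6 * N * D (not from the source)."""
    n_tokens = tokens_per_param * n_params
    return 6.0 * n_params * n_tokens


n_params = 610e6                              # 610M-parameter model from the abstract
d2z_flops = training_flops(n_params, 80)      # D2Z run at 80 TPP
baseline_flops = training_flops(n_params, 200)  # 10x-decay run at 200 TPP

# At a fixed model size, compute scales with TPP, so savings = 1 - 80/200.
savings = 1.0 - d2z_flops / baseline_flops
print(f"compute savings: {savings:.0%}")
```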
Innovation

Methods, ideas, or system contributions that make the work stand out.

Linear decay-to-zero schedule
Optimal peak learning rate
AdamW exponential moving average
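
The exponential-moving-average reading of AdamW mentioned above can be sketched from the standard decoupled-weight-decay update. The notation here is an assumption and may differ from the paper's: $\eta_t$ is the learning rate at step $t$, $\lambda$ the weight decay, and $u_t$ the Adam update direction.

```latex
\begin{align*}
\theta_t &= \theta_{t-1} - \eta_t \left( u_t + \lambda\, \theta_{t-1} \right) \\
         &= (1 - \eta_t \lambda)\, \theta_{t-1} + \eta_t \lambda \left( -\frac{u_t}{\lambda} \right) \\
         &= (1 - \alpha_t)\, \theta_{t-1} + \alpha_t\, x_t,
\qquad \alpha_t = \eta_t \lambda, \quad x_t = -\frac{u_t}{\lambda}.
\end{align*}
```

In this form the weights are an EMA of the scaled updates $x_t$ with smoothing coefficient $\alpha_t$. As $\eta_t$ decays toward zero, $\alpha_t$ shrinks and the effective averaging window (roughly $1/\alpha_t$ steps) widens, which is consistent with the abstract's point about averaging over more updates late in training.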