🤖 AI Summary
This work identifies the root cause of the gradient norm surges observed in the late stages of large language model (LLM) training: an implicit coupling among weight decay, normalization layers (e.g., LayerNorm), and the learning rate schedule. The authors propose an analytical framework that quantitatively characterizes how this coupling destabilizes parameter update directions and drives gradient growth. Building on this insight, they design a lightweight correction, modifying only *where* and *how* weight decay is applied, that requires no additional computation, architectural changes, or hyperparameter tuning. Experiments across multiple LLM training tasks show that the method keeps gradient norms stable throughout training, reduces training loss, and lowers validation perplexity by 0.8–1.2 on average, improving training stability and final model quality without compromising efficiency or scalability.
📝 Abstract
During long-duration Large Language Model (LLM) training runs, the gradient norm increases rapidly near the end of training. In this short note, we show that this increase is due to an unintended interaction between weight decay, normalization layers, and the learning rate schedule. We propose a simple correction that fixes this behavior while also yielding lower loss values throughout training.
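The abstract does not spell out the mechanism or the correction, so the following is a hedged sketch of one standard account of how such an interaction can arise, not the paper's actual method. For a weight matrix feeding into a normalization layer, the loss depends only on the weight's *direction*, so its gradient is orthogonal to the weight with magnitude proportional to 1/‖w‖. Decoupled weight decay shrinks ‖w‖ toward an equilibrium set by the decay-to-learning-rate ratio; as the learning rate is annealed with the decay coefficient held fixed, that equilibrium norm falls and the gradient norm 1/‖w‖ rises. The toy simulation below (all parameter values and the lr-scaled decay "correction" are illustrative assumptions) compares a constant decay coefficient against one scaled with the schedule:

```python
import numpy as np

def simulate(steps=20000, dim=8, lr_max=0.1, wd=0.1,
             scale_wd_with_lr=False, seed=0):
    """Toy scale-invariant weight vector: the loss depends only on the
    direction of w (as for weights feeding a normalization layer), so the
    gradient is orthogonal to w with magnitude 1/||w||.  Returns the
    gradient norm recorded at every step."""
    rng = np.random.default_rng(seed)
    w = rng.normal(size=dim)
    grad_norms = np.empty(steps)
    for t in range(steps):
        lr = lr_max * 0.5 * (1.0 + np.cos(np.pi * t / steps))  # cosine schedule
        g = rng.normal(size=dim)
        g -= (g @ w) / (w @ w) * w        # scale invariance: grad is orthogonal to w
        g /= np.linalg.norm(g)            # unit direction ...
        g /= np.linalg.norm(w)            # ... with magnitude 1/||w||
        # Hypothetical correction: tie the decay strength to the schedule so
        # the decay/lr ratio -- which sets the equilibrium norm of w -- stays
        # constant as the learning rate is annealed.
        wd_t = wd * (lr / lr_max) if scale_wd_with_lr else wd
        w = (1.0 - lr * wd_t) * w - lr * g  # SGD step with decoupled weight decay
        grad_norms[t] = np.linalg.norm(g)
    return grad_norms

plain = simulate(scale_wd_with_lr=False)
fixed = simulate(scale_wd_with_lr=True)
print(f"late/early grad-norm ratio, constant wd: "
      f"{plain[-500:].mean() / plain[1000:2000].mean():.2f}")
print(f"late/early grad-norm ratio, lr-scaled wd: "
      f"{fixed[-500:].mean() / fixed[1000:2000].mean():.2f}")
```

Under these assumptions the constant-decay run shows the late-training gradient surge the note describes, while scaling the decay with the schedule keeps the gradient norm roughly flat; the real correction may differ in detail.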