🤖 AI Summary
This work studies the heavy-tailed class imbalance pervasive in language modeling and investigates why adaptive optimizers (e.g., Adam) outperform standard gradient descent (GD). We identify that, under heavy-tailed label distributions, sign-based coordinate-wise descent—i.e., steepest descent w.r.t. the ℓ∞ norm—achieves faster convergence than ℓ₂-norm-normalized gradient descent. To formalize this insight, we propose a minimalistic theoretical model of next-token prediction. Leveraging non-Euclidean steepest descent theory and coordinate-adaptive optimization frameworks, we provide a rigorous convergence analysis demonstrating accelerated rates for sign descent under heavy-tailed data assumptions. Our analysis explicitly links optimizer behavior to distributional properties of the training data, yielding the first data-distribution-aware theoretical explanation for the empirical effectiveness of adaptive optimization in large language model pretraining.
📝 Abstract
Adaptive optimization methods (such as Adam) play a major role in LLM pretraining, significantly outperforming Gradient Descent (GD). Recent studies have proposed new smoothness assumptions on the loss function to explain the advantages over GD of adaptive algorithms with structured preconditioners (e.g., coordinate-wise or layer-wise) and of steepest descent methods w.r.t. non-Euclidean norms (e.g., the $\ell_\infty$ norm or the spectral norm). However, it remains unclear how these smoothness assumptions manifest in language modeling tasks. In this work, we aim to analyze the benefit of $\ell_\infty$-norm descent (a.k.a. sign descent) directly from properties of the data distribution, namely, heavy-tailed class imbalance. We propose a minimal yet representative setting of next-token prediction in which we can provably show faster convergence of coordinate-wise algorithms such as sign descent (steepest descent w.r.t. the $\ell_\infty$ norm) over normalized GD (steepest descent w.r.t. the $\ell_2$ norm) in the presence of heavy-tailed class imbalance.
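The contrast between the two steepest-descent geometries can be illustrated numerically. The sketch below is our own toy construction (not necessarily the paper's exact setting): a bias-only softmax next-token model whose target class frequencies follow a Zipfian (heavy-tailed) law. Sign descent moves every logit at full speed, while normalized GD concentrates its fixed $\ell_2$ step budget on the few head classes, so the tail logits barely move.

```python
import numpy as np

V = 1000                              # vocabulary size (assumed)
freq = 1.0 / np.arange(1, V + 1)      # Zipfian (heavy-tailed) class frequencies
freq /= freq.sum()

def loss(b):
    """Mean cross-entropy of a bias-only softmax model under frequencies `freq`."""
    logits = b - b.max()              # shift for numerical stability
    return np.log(np.exp(logits).sum()) - (freq * logits).sum()

def grad(b):
    """Gradient of the mean cross-entropy: softmax(b) - freq."""
    p = np.exp(b - b.max())
    return p / p.sum() - freq

def run(direction, lr=0.01, steps=1000):
    """Fixed-step steepest descent; `direction` maps the gradient to a unit step."""
    b = np.zeros(V)
    for _ in range(steps):
        b -= lr * direction(grad(b))
    return loss(b)

# Steepest descent w.r.t. l_inf (sign descent) vs. w.r.t. l_2 (normalized GD).
sign_loss = run(np.sign)
ngd_loss = run(lambda g: g / np.linalg.norm(g))
optimum = -(freq * np.log(freq)).sum()   # entropy of freq = minimal achievable loss
print(f"sign descent: {sign_loss:.3f}, normalized GD: {ngd_loss:.3f}, "
      f"optimum (entropy): {optimum:.3f}")
```

With a heavier tail (larger `V`), the $\ell_2$ distance to the optimum grows roughly like $\sqrt{V}$ while the $\ell_\infty$ distance stays logarithmic, which is why sign descent closes the gap to the entropy much faster under the same per-step budget.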