🤖 AI Summary
Problem: In distributed data parallel (DDP) training, infrequent-communication schemes (e.g., Local SGD) suffer degraded convergence when combined with adaptive optimizers. The cause is a time-scale mismatch: a single fast-moving momentum, tuned for frequent synchronization, decays too quickly to smooth gradients over the long intervals between synchronizations, so optimization becomes noise-dominated.
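To make the infrequent-communication pattern concrete, here is a minimal toy sketch of Local SGD. It assumes a one-dimensional quadratic objective f(x) = 0.5 (x - target)^2 with Gaussian gradient noise. Each simulated worker takes several local steps on its own copy of the parameter, and the copies are averaged once per round, so there is one synchronization per round rather than one per step. The function name `local_sgd` and all hyperparameters are hypothetical. The sketch uses plain SGD rather than an adaptive optimizer and is not the experimental setup summarized here.

```python
import random
from statistics import fmean


def local_sgd(num_workers=4, local_steps=8, rounds=50, lr=0.1,
              noise=0.5, target=3.0, seed=0):
    """Toy Local SGD on f(x) = 0.5 * (x - target)**2 (hypothetical setup).

    Returns the averaged parameter and the number of synchronizations.
    """
    rng = random.Random(seed)
    x_global = 0.0
    syncs = 0
    for _ in range(rounds):
        finals = []
        for _ in range(num_workers):
            x = x_global  # every worker starts the round from the shared value
            for _ in range(local_steps):
                # stochastic gradient: exact gradient plus Gaussian noise
                grad = (x - target) + rng.gauss(0.0, noise)
                x -= lr * grad
            finals.append(x)
        x_global = fmean(finals)  # one averaging (all-reduce) per round
        syncs += 1
    return x_global, syncs
```

With the defaults, each worker takes 400 local steps but the workers synchronize only 50 times. Setting `local_steps=1` averages after every step, which for plain SGD is equivalent to fully synchronous data-parallel training.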
Method: We propose MT-DAO, a family of adaptive optimizers that maintain multiple first momenta on different time scales. Slow-moving momenta smooth gradients over long horizons, while fast-moving momenta (or the gradient itself) track recent local update dynamics.
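The slow/fast momentum idea can be illustrated with a hypothetical Adam-style step for a single scalar parameter. In this sketch, the first moment is a fixed convex mix of a fast exponential moving average (`beta_fast`) and a slow one (`beta_slow`) of the gradient. The names, the mixing rule `mix`, and the default hyperparameters are assumptions for illustration only. This is not the published MT-DAO update.

```python
import math
from dataclasses import dataclass


@dataclass
class MultiMomentumState:
    """Per-parameter optimizer state (hypothetical layout)."""
    m_fast: float = 0.0  # fast EMA of gradients
    m_slow: float = 0.0  # slow EMA of gradients
    v: float = 0.0       # EMA of squared gradients, as in Adam
    t: int = 0           # step counter for bias correction


def multi_momentum_step(x, grad, s, lr=1e-2, beta_fast=0.9, beta_slow=0.999,
                        beta2=0.999, mix=0.5, eps=1e-8):
    """One illustrative update; `mix` weights the slow momentum."""
    s.t += 1
    s.m_fast = beta_fast * s.m_fast + (1 - beta_fast) * grad
    s.m_slow = beta_slow * s.m_slow + (1 - beta_slow) * grad
    s.v = beta2 * s.v + (1 - beta2) * grad * grad
    # Adam-style bias correction for each running average
    m_fast_hat = s.m_fast / (1 - beta_fast ** s.t)
    m_slow_hat = s.m_slow / (1 - beta_slow ** s.t)
    v_hat = s.v / (1 - beta2 ** s.t)
    # blend the two time scales into a single first moment
    m = (1 - mix) * m_fast_hat + mix * m_slow_hat
    return x - lr * m / (math.sqrt(v_hat) + eps)
```

Setting `mix=0.0` reduces the update to a standard bias-corrected Adam step with `beta1 = beta_fast`. Larger `mix` values shift weight toward the long-horizon average.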
Contribution/Results: We provide the first convergence guarantees for this family, establishing convergence to first-order stationary points for non-convex objectives. In language-model pre-training, MT-DAO closes the performance gap with fully synchronous DDP, achieves lower perplexity than infrequent-communication baselines, and reduces iso-token wall-clock time by 6–27% on Ethernet interconnects. At the 720M-parameter scale, it reaches a target perplexity in 24% fewer steps and 35% less time than the single-momentum DDP baseline, making cross-datacenter training considerably more efficient.
📝 Abstract
Training large models with distributed data parallelism (DDP) requires frequent communication of gradients across workers, which can saturate network bandwidth. Infrequent-communication strategies (e.g., Local SGD) reduce this overhead but, when applied to adaptive optimizers, often suffer a performance gap relative to fully synchronous DDP. We trace this gap to a time-scale mismatch: the optimizer's fast-moving momentum, tuned for frequent updates, decays too quickly to smooth gradients over long intervals, leading to noise-dominated optimization. To address this, we propose MT-DAO, a family of optimizers that employs multiple slow- and fast-moving first momenta, or the gradient itself, to track update dynamics across different time scales, and for which we provide the first convergence guarantees. Empirically, for language-model pre-training, MT-DAO eliminates the performance gap with DDP, outperforms infrequent-communication baselines in perplexity, and reduces iso-token wall-clock time by 6–27% on Ethernet interconnects. At the 720M scale, MT-DAO reaches a target perplexity in 24% fewer steps and 35% less time than the single-momentum DDP baseline. MT-DAO enables effective cross-datacenter and geographically distributed training.
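The time-scale argument can be made concrete by looking at how much weight an exponential moving average (EMA) of gradients with decay β still places on gradients from before the most recent synchronization. This assumes a standard EMA momentum. The interval H = 32 below is a hypothetical choice for illustration, not a value from the paper.

```latex
m_t = \beta\, m_{t-1} + (1-\beta)\, g_t
    = (1-\beta) \sum_{k=0}^{t-1} \beta^{k}\, g_{t-k} + \beta^{t} m_0

% weight on a gradient that is k steps old, and the half-life of that weight
w_k = (1-\beta)\,\beta^{k}, \qquad
k_{1/2} = \frac{\ln 2}{\ln(1/\beta)}

% \beta = 0.9:   k_{1/2} \approx 6.6 steps
% \beta = 0.999: k_{1/2} \approx 693 steps

% share of the EMA's weight on gradients older than H steps:
\sum_{k \ge H} (1-\beta)\,\beta^{k} = \beta^{H}
% with H = 32:  0.9^{32} \approx 0.034, \quad 0.999^{32} \approx 0.968
```

By the end of a 32-step local phase, a fast momentum (β = 0.9) places only about 3% of its weight on gradients from before the phase began, while a slow momentum (β = 0.999) still places about 97% there. This is consistent with the abstract's intuition for combining both time scales.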