Stochastic Rounding for LLM Training: Theory and Practice

📅 2025-02-27
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
To address numerical instability and suboptimal convergence efficiency in low-precision training of large language models (LLMs), this paper proposes a BF16 mixed-precision training framework augmented with stochastic rounding (SR). We establish, for the first time, the implicit regularization effect and theoretical convergence guarantees of SR within the Adam optimizer. Furthermore, we seamlessly extend BF16+SR to distributed training, enabling automatic, robust large-scale deployment. Empirical evaluation on 6.7B-model pretraining demonstrates that our method achieves significantly lower perplexity compared to standard BF16 and FP32 baselines, while improving throughput by 1.54× and reducing GPU memory consumption by 30%. This work introduces a new paradigm for LLM training—rigorously grounded in theory and validated in practice—that simultaneously delivers high accuracy, high efficiency, and low computational overhead.

Technology Category

Application Category

📝 Abstract
As the parameters of Large Language Models (LLMs) have scaled to hundreds of billions, the demand for efficient training methods -- balancing faster computation and reduced memory usage without sacrificing accuracy -- has become more critical than ever. In recent years, various mixed precision strategies, which involve different precision levels for optimization components, have been proposed to increase training speed with minimal accuracy degradation. However, these strategies often require manual adjustments and lack theoretical justification. In this work, we leverage stochastic rounding (SR) to address numerical errors of training with low-precision representation. We provide theoretical analyses of implicit regularization and convergence under the Adam optimizer when SR is utilized. With the insights from these analyses, we extend previous BF16 + SR strategy to be used in distributed settings, enhancing the stability and performance for large scale training. Empirical results from pre-training models with up to 6.7B parameters, for the first time, demonstrate that our BF16 with SR strategy outperforms (BF16, FP32) mixed precision strategies, achieving better validation perplexity, up to $1.54 imes$ higher throughput, and $30%$ less memory usage.
Problem

Research questions and friction points this paper is trying to address.

Address numerical errors in low-precision LLM training.
Enhance stability and performance in distributed training settings.
Improve training efficiency with better throughput and memory usage.
Innovation

Methods, ideas, or system contributions that make the work stand out.

Stochastic rounding reduces numerical errors
BF16 + SR enhances distributed training stability
Achieves higher throughput with less memory
🔎 Similar Papers
No similar papers found.