Recipes for Pre-training LLMs with MXFP8

📅 2025-05-30
🏛️ arXiv.org
📈 Citations: 0
Influential: 0
🤖 AI Summary
To address the instability and divergence observed when using the MXFP8 low-precision format in multi-trillion-token LLM pre-training, this work proposes computing per-block scaling factors with round-to-infinity rounding rather than the rounding mode suggested by the OCP specification, enabling stable training on the NVIDIA Blackwell architecture. It achieves, for the first time, end-to-end MXFP8 pre-training of an 8B-parameter model on a 15T-token dataset, closely matching BF16 convergence behavior. The approach reduces GPU memory footprint by ~40% and memory bandwidth requirements by ~35%, without requiring auxiliary precision-recovery modules or mixed-precision fallbacks, yielding significantly improved hardware efficiency and scalability for ultra-large-scale training. To the authors' knowledge, it constitutes the first high-fidelity, full-stage solution for low-bit LLM pre-training.

📝 Abstract
Precision scaling - using fewer bits to represent model parameters and related tensors during pre-training - has emerged as a compelling technique for improving GPU efficiency without sacrificing accuracy. Microscaling (MX) formats in NVIDIA's latest Blackwell GPUs represent a major leap in enabling this precision scaling. These formats combine narrow floating-point data types with per-block scaling factors, offering a fine-grained approach to quantizing tensors. Although MX formats offer the promise of improved numeric stability compared to other reduced-precision representations, in practice they must be used carefully in order to successfully converge an LLM on a multi-trillion-token dataset. In this paper, we show that the rounding mode suggested in the OCP specification can lead to divergence when pre-training an LLM. We show that an improved rounding mode, which uses round-to-infinity to compute scaling factors, enables successful pre-training in MXFP8 of an 8B model on 15T tokens.
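The scaling-factor difference the abstract describes can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the block size of 32 and the E4M3 maximum of 448 follow the OCP MX specification, while the function names are hypothetical and element-level rounding to 3 mantissa bits is elided.

```python
import math

E4M3_MAX = 448.0   # largest finite magnitude representable in FP8 E4M3
BLOCK = 32         # MX block size: 32 elements share one power-of-two scale

def scale_exp_ocp(amax):
    # OCP v1.0 suggestion: floor(log2(amax)) - emax_elem, with emax_elem = 8
    # for E4M3. The floor can leave amax / 2^e above 448, forcing saturation.
    return math.floor(math.log2(amax)) - 8

def scale_exp_rti(amax):
    # Round-to-infinity variant: round the scale exponent up, guaranteeing
    # amax / 2^e <= E4M3_MAX so the block maximum never saturates.
    return math.ceil(math.log2(amax / E4M3_MAX))

def quantize_block(values):
    """Quantize one MX block with a round-to-infinity scale (sketch)."""
    amax = max(abs(v) for v in values)
    e = scale_exp_rti(amax)
    scale = 2.0 ** e
    # Element conversion to E4M3 is approximated here by clamping only;
    # real hardware also rounds each mantissa to 3 bits.
    q = [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in values]
    return q, scale
```

For a block maximum of 480, for example, the OCP floor rule yields a scale of 1, so 480 saturates to 448, while round-to-infinity yields a scale of 2, mapping the maximum to a representable 240. Repeated saturation of block maxima is the kind of error the abstract links to divergence.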
Problem

Research questions and friction points this paper is trying to address.

Optimizing MXFP8 pre-training parameters for efficiency
Achieving BF16 accuracy with 8-bit quantization techniques
Enabling larger model training with reduced GPU memory usage
Innovation

Methods, ideas, or system contributions that make the work stand out.

MXFP8-E4M3 datatype for efficient pre-training
Per-block scaling factors enabling fine-grained tensor quantization
Number conversion algorithm matching BF16 training accuracy