Training LLMs with MXFP4

📅 2025-02-27
🤖 AI Summary
To address the significant model-quality degradation caused by MXFP4 low-precision training, this paper proposes what it describes as the first near-lossless MXFP4 training recipe for large language models. Stochastic rounding (SR) yields unbiased gradient estimates. A random Hadamard transform is applied before SR to provably bound its variance, which block-level outliers would otherwise inflate, harming convergence. The approach combines MXFP4 GEMMs, SR, and Hadamard preprocessing within a mixed-precision training setup. On GPT models with up to 6.7B parameters, it shows minimal quality degradation relative to mixed-precision BF16 training. More than half of the training FLOPs are computed in MXFP4, for an estimated backward-pass speedup of >1.3× over FP8 and >1.7× over BF16, which lowers training cost.
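
The unbiasedness and variance claims rest on an elementary property of SR, written out below for a single value $v$ between adjacent representable values $a \le v \le b$. This is the standard argument, not the paper's own theorem, and the link to MXFP4 block scales that follows is an illustration.

$$
\mathrm{SR}(v) =
\begin{cases}
b & \text{with probability } \dfrac{v-a}{b-a},\\[4pt]
a & \text{with probability } \dfrac{b-v}{b-a},
\end{cases}
\qquad
\mathbb{E}[\mathrm{SR}(v)] = a + (b-a)\,\frac{v-a}{b-a} = v,
\qquad
\operatorname{Var}[\mathrm{SR}(v)] = (b-v)(v-a) \le \frac{(b-a)^2}{4}.
$$

In MXFP4, each group of 32 values shares one power-of-two scale, so the gap $b-a$ seen by every element grows with the largest magnitude in its block. A single outlier therefore raises the SR variance bound for the whole block, which is the effect the random Hadamard transform is used to counter.
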

📝 Abstract
Low precision (LP) datatypes such as MXFP4 can accelerate matrix multiplications (GEMMs) and reduce training costs. However, directly using MXFP4 instead of BF16 during training significantly degrades model quality. In this work, we present the first near-lossless training recipe that uses MXFP4 GEMMs, which are $2\times$ faster than FP8 on supported hardware. Our key insight is to compute unbiased gradient estimates with stochastic rounding (SR), resulting in more accurate model updates. However, directly applying SR to MXFP4 can result in high variance from block-level outliers, harming convergence. To overcome this, we use the random Hadamard transform to theoretically bound the variance of SR. We train GPT models with up to 6.7B parameters and find that our method induces minimal degradation relative to mixed-precision BF16 training. Our recipe computes $>1/2$ of the training FLOPs in MXFP4, enabling an estimated speedup of $>1.3\times$ over FP8 and $>1.7\times$ over BF16 during backpropagation.
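
To make these ingredients concrete, here is a minimal NumPy sketch of MXFP4 quantize-dequantize with stochastic rounding. It assumes the OCP Microscaling layout (blocks of 32 FP4 E2M1 values sharing a power-of-two scale). The scale rule, the function name `mxfp4_sr`, and the outlier experiment are illustrative assumptions, not the paper's kernels.

```python
import numpy as np

# Non-negative magnitudes representable by FP4 E2M1, the MXFP4 element type.
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
BLOCK = 32  # in MXFP4, each group of 32 consecutive values shares one scale

def mxfp4_sr(x, rng):
    """Quantize-dequantize x to MXFP4 with stochastic rounding (SR).

    x: 1-D float array whose length is a multiple of BLOCK.
    Returns dequantized values so they can be compared with x directly.
    """
    blocks = x.reshape(-1, BLOCK)
    max_abs = np.abs(blocks).max(axis=1, keepdims=True)
    max_abs = np.where(max_abs > 0, max_abs, 1.0)  # avoid log2(0) for all-zero blocks
    # Assumption: smallest power-of-two scale with max_abs / scale <= 6, so SR
    # never needs to clip. The OCP MX spec's default scale rule differs.
    scale = 2.0 ** np.ceil(np.log2(max_abs / FP4_GRID[-1]))
    mag = np.abs(blocks) / scale
    # Neighbouring grid points lo <= mag <= hi.
    idx = np.searchsorted(FP4_GRID, mag, side="right") - 1
    idx = np.clip(idx, 0, len(FP4_GRID) - 2)
    lo, hi = FP4_GRID[idx], FP4_GRID[idx + 1]
    # Round up with probability (mag - lo) / (hi - lo), so E[q] = mag (unbiased).
    p_up = (mag - lo) / (hi - lo)
    q = np.where(rng.random(mag.shape) < p_up, hi, lo)
    return (np.sign(blocks) * q * scale).reshape(x.shape)

rng = np.random.default_rng(0)
x = rng.standard_normal(4 * BLOCK)
x[5] = 40.0  # an outlier coarsens the quantization grid for its whole block

draws = np.stack([mxfp4_sr(x, rng) for _ in range(5000)])
bias = np.abs(draws.mean(axis=0) - x)   # near 0: SR is unbiased
spread = draws.std(axis=0)              # per-element SR noise
print(bias.max())
print(spread[:BLOCK].mean(), spread[BLOCK:].mean())
```

Averaging many independent draws recovers `x`, which is the unbiasedness SR provides. The per-element spread in the first block, which holds the outlier, is far larger than in the other blocks, illustrating the variance problem described above.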
Problem

Achieving near-lossless training with MXFP4 precision.
Reducing the variance that block-level outliers introduce into stochastic rounding, so that model updates stay accurate.
Enabling faster training than FP8 and BF16.
Innovation

Uses MXFP4 GEMMs, which are 2× faster than FP8 on supported hardware
Applies stochastic rounding to obtain unbiased gradient estimates
Employs the random Hadamard transform to bound the variance of stochastic rounding
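
The variance fix can be sketched in the same spirit. A random Hadamard transform $R = \mathrm{diag}(s)\,H/\sqrt{g}$ with random signs $s$ is orthogonal, so applying it to both GEMM operands along the reduction dimension leaves the product unchanged while spreading an outlier across its block. The NumPy code below assumes a block-diagonal transform of size $g = 32$. The transform size, the helper names, and where the transform is applied are illustrative assumptions rather than the paper's exact configuration.

```python
import numpy as np

def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of 2)."""
    assert n > 0 and (n & (n - 1)) == 0, "n must be a power of 2"
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])
    return h

def random_hadamard(g, rng):
    """Orthogonal g x g matrix R = diag(s) @ H / sqrt(g) with random signs s."""
    signs = rng.choice([-1.0, 1.0], size=g)
    return signs[:, None] * hadamard(g) / np.sqrt(g)

def rht_right(X, R):
    """Return X @ blockdiag(R, ..., R): transform each length-g chunk of every row."""
    m, k = X.shape
    g = R.shape[0]
    return (X.reshape(m, k // g, g) @ R).reshape(m, k)

def crest_factor(v):
    """max|v| / RMS(v): a large value means one outlier dominates the block scale."""
    return np.max(np.abs(v)) / np.sqrt(np.mean(v ** 2))

rng = np.random.default_rng(0)
g = 32                       # transform size (assumption: equal to the MX block size)
M, K, N = 8, 128, 16
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
A[0, 5] = 40.0               # inject a block-level outlier

R = random_hadamard(g, rng)
A_t = rht_right(A, R)        # A @ D, with D = blockdiag(R, ..., R)
B_t = rht_right(B.T, R).T    # D^T @ B
# D is orthogonal, so the GEMM result is unchanged up to float rounding.
print(np.allclose(A_t @ B_t, A @ B))  # True

# The outlier is spread over its 32-element block, lowering the crest factor.
print(crest_factor(A[0, :g]), crest_factor(A_t[0, :g]))
```

The transform preserves each block's RMS, so a lower crest factor ($\max|v|/\mathrm{RMS}(v)$) means a smaller shared MXFP4 scale relative to typical elements. That gives finer quantization steps and a lower SR variance.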