MXNorm: Reusing MXFP block scales for efficient tensor normalisation

πŸ“… 2026-03-13
πŸ“ˆ Citations: 0
✨ Influential: 0
πŸ€– AI Summary
This work addresses the inefficiency of normalization operations in deep learning when deployed on low-precision accelerators such as MXFP8, since these operations typically still require high-precision computation. To overcome this limitation, the authors propose MXNorm, a plug-and-play replacement for RMSNorm that reuses the block-wise scales already computed during MXFP format conversion to approximate the root mean square, thereby eliminating the need for an additional high-precision reduction. When integrated with torch.compile optimizations, MXNorm achieves nearly lossless accuracy in pretraining Llama 3 models ranging from 125M to 8B parameters, while delivering up to a 2.4× kernel speedup. For Llama 3 8B transformer layers, this corresponds to a 1.3% speedup in MXFP8 and 2.6% in NVFP4.

πŸ“ Abstract
Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in performance on reductions and elementwise computations, which are still being performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales calculated as part of the MXFP8 cast and enables a 32x decrease in the size of reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models of 125M, 1B and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also show practical kernel speedups using only torch.compile of up to 2.4x for MXNorm over RMSNorm, corresponding to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
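The core idea can be sketched numerically. In MXFP8, each block of 32 elements shares one power-of-two scale chosen from the block's absolute maximum during the cast; MXNorm reuses those scales to estimate the RMS, so the normalization reduction runs over 32Γ— fewer values. The sketch below assumes a common amax-based scale convention (E4M3 maximum of 448) and a hypothetical calibration constant `c`; the paper's exact estimator may differ.

```python
import numpy as np

BLOCK = 32  # MXFP8 block size: one shared power-of-two scale per 32 elements


def mx_block_scales(x):
    """Per-block power-of-two scales, as computed during an MXFP8 cast.

    Assumes the common amax-based convention: the scale is the largest
    power of two such that the block's amax maps into E4M3 range (max 448).
    """
    blocks = x.reshape(-1, BLOCK)
    amax = np.abs(blocks).max(axis=1)
    return 2.0 ** np.floor(np.log2(amax / 448.0))


def mxnorm_rms_estimate(x, c=1.0):
    """Estimate RMS(x) from the block scales alone.

    The reduction is over one value per block, i.e. 32x smaller than a
    full RMS reduction. `c` is a hypothetical calibration constant
    (not taken from the paper).
    """
    s = mx_block_scales(x)
    return c * np.sqrt(np.mean(s ** 2))


def true_rms(x):
    """Reference RMS over all elements, for comparison."""
    return np.sqrt(np.mean(x ** 2))
```

Because each scale is a power of two derived from the block amax, the estimate tracks the tensor's overall magnitude: doubling the input exactly doubles every block scale, and hence the estimate, which is the property a drop-in RMSNorm replacement needs.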
Problem

Research questions and friction points this paper is trying to address.

tensor normalization
low-precision training
reduction bottleneck
RMSNorm
MXFP8
Innovation

Methods, ideas, or system contributions that make the work stand out.

MXNorm
low-precision computing
tensor normalization
block scales
RMSNorm
Authors
Callum McLean (Graphcore)
Luke Y. Prince (Graphcore)
Alexandre Payot (Graphcore)
Paul BalanΓ§a (Graphcore)
Carlo Luschi (VP & Head of Research, Graphcore)