AI Summary
This work addresses the inefficiency of normalization operations in deep learning models deployed on low-precision accelerators: even when matrix multiplications run in formats such as MXFP8, normalization still requires high-precision reductions. To overcome this limitation, the authors propose MXNorm, a plug-and-play replacement for RMSNorm that reuses the block-wise scales already computed during MXFP format conversion to approximate the root mean square, eliminating the need for an additional high-precision reduction over the full activation. Combined with torch.compile optimizations, MXNorm achieves nearly lossless accuracy when pretraining Llama 3 models ranging from 125M to 8B parameters, while delivering up to a 2.4× kernel-level speedup over RMSNorm. End-to-end training acceleration reaches 1.3% with MXFP8 and 2.6% with NVFP4 on the Llama 3 8B model.
Abstract
Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in reductions and elementwise computations, which are still performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales calculated as part of the MXFP8 cast, enabling a 32x decrease in the size of the reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models with 125M, 1B, and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also show practical kernel speedups, using only torch.compile, of up to 2.4x for MXNorm over RMSNorm, corresponding to a 1.3% speedup for Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
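The core idea, that the per-block scales produced by the MXFP8 cast already carry enough magnitude information to estimate the RMS over a 32x-smaller vector, can be sketched in NumPy. This is an illustrative assumption, not the paper's implementation: the scale computation is simplified (a real MX cast also accounts for the element format's exponent range), and the names `mx_block_scales`, `mxnorm`, and the calibration constant `c` are hypothetical.

```python
import numpy as np

BLOCK = 32  # MX block size: one shared scale per 32 elements


def mx_block_scales(x):
    """Per-block power-of-two scales, as a simplified MXFP8 cast would
    compute them: each block of 32 values shares one scale derived
    from the block's absolute maximum."""
    blocks = x.reshape(-1, BLOCK)
    absmax = np.abs(blocks).max(axis=1)
    # Round the absmax down to a power of two (guard against log2(0)).
    return np.exp2(np.floor(np.log2(np.maximum(absmax, 1e-30))))


def mxnorm(x, g, c=1.0):
    """Hypothetical MXNorm sketch: estimate RMS(x) from the block scales
    alone, so the normalization reduction runs over len(x) // 32 values
    instead of len(x). `c` is an assumed calibration constant relating
    the RMS of the scales to the RMS of the data."""
    scales = mx_block_scales(x)
    rms_est = c * np.sqrt(np.mean(scales ** 2))  # 32x-smaller reduction
    return x / rms_est * g
```

In a real kernel, the scales would come for free from the MXFP8 cast that precedes the matmul, so `mx_block_scales` would not be a separate pass; only the short reduction over the scale vector remains.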