MuonBP: Faster Muon via Block-Periodic Orthogonalization

📅 2025-10-19
🤖 AI Summary
In model-parallel training, gradient orthogonalization requires extra gather and scatter operations on gradient matrix shards held on different devices, which can cost 5–10% of throughput relative to Adam/AdamW. To address this, the paper proposes Muon with Block-Periodic Orthogonalization (MuonBP): each device orthogonalizes its own matrix shard independently, and a full orthogonalization is performed periodically to maintain training stability at scale. The authors show how to adjust the learning rate from the Muon baseline, give convergence guarantees, and find that the theory calls for two stepsizes: one for the blockwise steps and one for the full steps. The method requires minimal hyperparameter adjustments, achieves iteration complexity competitive with baseline Muon, and delivers per-iteration throughput comparable to coordinate-wise methods such as AdamW. When training an 8B-parameter model with eight-way tensor parallelism and ZeRO optimizer state sharding, MuonBP achieves an 8% throughput increase over Muon with no degradation in performance.

📝 Abstract
Gradient orthogonalization is a simple strategy that shows great utility in speeding up gradient descent. The Muon optimizer (Jordan, Jin, et al., 2024) combines gradient orthogonalization with first-order momentum and achieves significant improvement in data efficiency over Adam/AdamW (Loshchilov and Hutter, 2019) for language model training. However, when using model parallelism, gradient orthogonalization introduces additional overhead compared to coordinate-wise optimizers (such as AdamW) due to additional gather and scatter operations on gradient matrix shards from different devices. This additional communication can amount to a throughput hit of 5%-10% compared to Adam/AdamW. To remedy this, we propose Muon with Block-Periodic Orthogonalization (MuonBP), which applies orthogonalization independently to matrix shards on each device and periodically performs full orthogonalization to maintain training stability at scale. We show how to adjust the learning rate from the baseline to MuonBP and give convergence guarantees for this algorithm. Crucially, our theory dictates that we use two stepsizes: one for the blockwise orthogonalization steps, and one for the full orthogonalization steps. Our method is simple, requires minimal hyperparameter adjustments, and achieves competitive iteration complexity compared with baseline Muon while providing per-iteration throughput comparable to coordinate-wise methods such as AdamW. When training an 8B model with eight-way tensor parallelism and ZeRO optimizer state sharding, MuonBP achieves 8% throughput increase compared to Muon with no degradation in performance.
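The block-periodic idea in the abstract can be illustrated with a small pure-Python sketch. It is not the authors' implementation. The names `orthogonalize` and `muonbp_direction`, the `period` and `num_shards` parameters, the row-wise sharding, and the choice to run the full step whenever `step % period == 0` are all assumptions made for illustration. The classic cubic Newton-Schulz iteration stands in for whatever orthogonalization routine the paper actually uses, and momentum is omitted.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def orthogonalize(g, iters=30):
    """Approximate the orthogonal polar factor of g.

    Uses the classic cubic Newton-Schulz iteration X <- 1.5 X - 0.5 X X^T X
    (an illustrative stand-in, not necessarily the routine used in the paper).
    Assumes g has at least as many columns as rows and full row rank.
    """
    # Dividing by the Frobenius norm puts all singular values in (0, 1],
    # which is inside the convergence region of the iteration.
    norm = sum(v * v for row in g for v in row) ** 0.5 or 1.0
    x = [[v / norm for v in row] for row in g]
    for _ in range(iters):
        cubic = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(r1, r2)]
             for r1, r2 in zip(x, cubic)]
    return x


def muonbp_direction(grad, step, period, num_shards):
    """Return (update direction, step kind) for one optimizer step.

    Hypothetical schedule: every `period`-th step orthogonalizes the whole
    matrix (this is the step that needs gather/scatter in a real system);
    all other steps orthogonalize each row shard on its own.
    Assumes len(grad) is divisible by num_shards.
    """
    if step % period == 0:
        return orthogonalize(grad), "full"
    rows = len(grad) // num_shards
    direction = []
    for i in range(num_shards):
        shard = grad[i * rows:(i + 1) * rows]   # rows held by "device" i
        direction.extend(orthogonalize(shard))  # purely local computation
    return direction, "block"
```

For example, calling `muonbp_direction(grad, step=1, period=4, num_shards=2)` on a 4x4 gradient orthogonalizes the two 2x4 row shards separately, while `step=0` orthogonalizes the full matrix.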
Problem

Research questions and friction points this paper is trying to address.

Gather and scatter operations for gradient orthogonalization add communication overhead under model parallelism
Orthogonalizing only local shards must not compromise training stability at scale
Muon's throughput trails coordinate-wise optimizers such as AdamW by 5%-10%
Innovation

Methods, ideas, or system contributions that make the work stand out.

Block-periodic orthogonalization reduces communication overhead
Independent shard orthogonalization with periodic full synchronization
Dual learning rates for blockwise and full orthogonalization steps
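The interplay between the period and the two stepsizes listed above can be sketched as a simple schedule. This is a minimal illustration under stated assumptions: the function names, the `period` parameter, and the placement of the full step at the start of each period are hypothetical, and the stepsize values in the usage note are placeholders rather than values from the paper.

```python
def muonbp_schedule(num_steps, period, lr_block, lr_full):
    """List the (step kind, stepsize) pair used at each step.

    Hypothetical convention: steps divisible by `period` are full
    orthogonalization steps and use lr_full; all others are blockwise
    steps and use lr_block.
    """
    schedule = []
    for step in range(num_steps):
        if step % period == 0:
            schedule.append(("full", lr_full))
        else:
            schedule.append(("block", lr_block))
    return schedule


def full_step_fraction(schedule):
    """Fraction of steps that need the cross-device gather/scatter."""
    full_steps = sum(1 for kind, _ in schedule if kind == "full")
    return full_steps / len(schedule)
```

With `muonbp_schedule(8, 4, lr_block=0.01, lr_full=0.02)`, two of the eight steps are full steps, so only a quarter of the steps incur the extra communication. How the two stepsizes should relate to each other is determined by the paper's theory and is not encoded here.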