🤖 AI Summary
Neural network training is fundamentally a large-scale matrix optimization problem, yet existing methods often neglect the intrinsic matrix structure of model parameters. This paper proposes a low-rank orthogonalization optimization framework—the first to integrate low-rank matrix decomposition and orthogonalization directly into gradient updates, combining matrix-sign gradient descent with a low-rank variant of the Muon optimizer. By exploiting the empirically observed low-rank structure of gradients, our method achieves significant computational efficiency gains while preserving theoretical rigor. It outperforms fine-tuned baseline Muon in GPT-2 and LLaMA pretraining. We provide rigorous convergence analysis, proving superior iteration complexity and robust convergence under heavy-tailed gradient noise. This work establishes a novel matrix-aware optimization paradigm for large foundation models—balancing efficiency, theoretical soundness, and interpretability.
📝 Abstract
Neural network (NN) training is inherently a large-scale matrix optimization problem, yet the matrix structure of NN parameters has long been overlooked. Recently, the optimizer Muon cite{jordanmuon}, which explicitly exploits this structure, has gained significant attention for its strong performance in foundation model training. A key component contributing to Muon's success is matrix orthogonalization. In this paper, we propose {it low-rank orthogonalization}, which explicitly leverages the low-rank nature of gradients during NN training. Building on this, we propose low-rank matrix-signed gradient descent and a low-rank variant of Muon. Our numerical experiments demonstrate the superior performance of low-rank orthogonalization, with the low-rank Muon achieving promising results in GPT-2 and LLaMA pretraining -- surpassing the performance of the carefully tuned vanilla Muon. Theoretically, we establish the iteration complexity of the low-rank matrix-signed gradient descent for finding an approximate stationary solution, as well as that of low-rank Muon for finding an approximate stochastic stationary solution under heavy-tailed noise.