🤖 AI Summary
Distributed training of foundation models suffers from slow convergence, and existing convergence guarantees rest on stringent theoretical assumptions (such as bounded gradient variance or strong convexity) that are hard to satisfy under data heterogeneity.
Method: We propose PISA, a novel optimizer that integrates preconditioning techniques into an inexact stochastic ADMM framework. PISA requires only Lipschitz continuity of the gradient, a significantly weaker assumption, and is compatible with a range of second-moment estimation schemes (e.g., those of Adam and AdaGrad). It supports large-scale parallelism and inexact subproblem solves.
Results: Extensive experiments across vision models, large language models, GANs, reinforcement learning, and RNNs demonstrate that PISA consistently outperforms SGD, Adam, and AdaBelief in both convergence speed and final accuracy. Notably, under heterogeneous data settings, PISA achieves superior stability and generalization performance.
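The paper's exact update rules are not reproduced here, but the ingredients named above (consensus-style ADMM across nodes, a diagonal second-moment preconditioner, and inexact subproblem solving) can be illustrated with a minimal sketch. Everything below is an illustrative assumption, not the paper's setup: the quadratic local losses, the Adam-style preconditioner, all hyperparameters, and the use of deterministic gradients in place of stochastic estimates.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic heterogeneous nodes: node i holds f_i(x) = 0.5*||A_i x - b_i||^2,
# with very different scales across nodes to mimic data heterogeneity.
# (Hypothetical test problem, not from the paper.)
n_nodes, dim = 4, 10
A = [rng.normal(scale=s, size=(20, dim)) for s in (0.5, 1.0, 2.0, 4.0)]
x_true = rng.normal(size=dim)
b = [Ai @ x_true + 0.01 * rng.normal(size=20) for Ai in A]

def local_grad(i, x):
    # Full local gradient; PISA would use a stochastic estimate here.
    return A[i].T @ (A[i] @ x - b[i])

rho, lr, beta2, eps, iters = 5.0, 0.05, 0.999, 1e-8, 400  # illustrative values

x = [np.zeros(dim) for _ in range(n_nodes)]  # local primal variables
y = [np.zeros(dim) for _ in range(n_nodes)]  # dual variables
v = [np.zeros(dim) for _ in range(n_nodes)]  # second-moment accumulators
z = np.zeros(dim)                            # consensus variable

for t in range(1, iters + 1):
    for i in range(n_nodes):                 # embarrassingly parallel across nodes
        g = local_grad(i, x[i])
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        vhat = v[i] / (1 - beta2 ** t)       # bias-corrected second moment
        P = np.sqrt(vhat) + eps              # Adam-style diagonal preconditioner
        # Inexact x-update: one preconditioned gradient step on the augmented
        # Lagrangian f_i(x_i) + y_i.(x_i - z) + (rho/2)||x_i - z||^2,
        # instead of solving the subproblem exactly.
        x[i] -= lr * (g + y[i] + rho * (x[i] - z)) / P
    # Exact z-update of consensus ADMM: average of shifted local variables.
    z = np.mean([x[i] + y[i] / rho for i in range(n_nodes)], axis=0)
    for i in range(n_nodes):
        y[i] += rho * (x[i] - z)             # dual ascent step

total_loss = sum(0.5 * np.linalg.norm(A[i] @ z - b[i]) ** 2
                 for i in range(n_nodes))
```

On this toy problem the consensus iterate `z` drives the summed local losses well below their value at the zero initialization; the sketch is only meant to show how preconditioning slots into the ADMM x-update, not to reproduce the paper's convergence behavior.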
📝 Abstract
The recent advancement of foundation models (FMs) has brought about a paradigm shift, revolutionizing various sectors worldwide. The popular optimizers used to train these models are stochastic gradient descent-based algorithms, which face inherent limitations, such as slow convergence and stringent assumptions for convergence. In particular, data heterogeneity arising from distributed settings poses significant challenges to their theoretical and numerical performance. This paper develops an algorithm, PISA (Preconditioned Inexact Stochastic Alternating Direction Method of Multipliers), which enables scalable parallel computing and supports various second-moment schemes. Grounded in rigorous theoretical guarantees, the algorithm converges under the sole assumption of Lipschitz continuity of the gradient, thereby removing the need for other conditions commonly imposed by stochastic methods. This capability enables PISA to tackle the challenge of data heterogeneity effectively. Comprehensive experimental evaluations for training or fine-tuning diverse FMs, including vision models, large language models, reinforcement learning models, generative adversarial networks, and recurrent neural networks, demonstrate its superior numerical performance compared to various state-of-the-art optimizers.