A Fast and Flat Federated Learning Method via Weighted Momentum and Sharpness-Aware Minimization

πŸ“… 2025-11-26
πŸ“ˆ Citations: 0
✨ Influential: 0
πŸ“„ PDF
πŸ€– AI Summary
To address communication inefficiency, slow convergence, and poor generalization in non-IID federated learning, the paper identifies two fundamental failure modes in existing momentum-augmented Sharpness-Aware Minimization (SAM) methods: (i) misalignment between local and global curvature, and (ii) momentum-induced oscillatory β€œecho” effects. The authors propose FedWMSAM, which introduces server-side momentum-guided global perturbation directions and an adaptive cosine-similarity-based coupling mechanism to achieve an efficient, single-backpropagation SAM approximation. They further design a two-stage momentum-SAM adaptive scheduling strategy and establish a novel non-IID convergence theory incorporating perturbation variance. Extensive experiments across multiple datasets and model architectures demonstrate that FedWMSAM significantly improves both convergence speed and generalization under realistic non-IID settings. The implementation is publicly available.

πŸ“ Abstract
In federated learning (FL), models must *converge quickly* under tight communication budgets while *generalizing* across non-IID client distributions. These twin requirements have naturally led to two widely used techniques: client/server *momentum* to accelerate progress, and *sharpness-aware minimization* (SAM) to prefer flat solutions. However, simply combining momentum and SAM leaves two structural issues unresolved in non-IID FL. We identify and formalize two failure modes: *local-global curvature misalignment* (local SAM directions need not reflect the global loss geometry) and *momentum-echo oscillation* (late-stage instability caused by accumulated momentum). To our knowledge, these failure modes have not been jointly articulated and addressed in the FL literature. We propose **FedWMSAM** to address both failure modes. First, we construct a momentum-guided global perturbation from server-aggregated momentum to align clients' SAM directions with the global descent geometry, enabling a *single-backprop* SAM approximation that preserves efficiency. Second, we couple momentum and SAM via a cosine-similarity adaptive rule, yielding an early-momentum, late-SAM two-phase training schedule. On the theory side, we provide a non-IID convergence bound that *explicitly models the perturbation-induced variance* $\sigma_\rho^2 = \sigma^2 + (L\rho)^2$ and its dependence on $(S, K, R, N)$. We conduct extensive experiments on multiple datasets and model architectures, and the results validate the effectiveness, adaptability, and robustness of our method, demonstrating its superiority in addressing the optimization challenges of federated learning. Our code is available at https://github.com/Huang-Yongzhi/NeurlPS_FedWMSAM.
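The single-backprop idea in the abstract can be sketched as follows: instead of the usual SAM ascent step (which costs a second backward pass per update), the perturbation is taken along the server-aggregated momentum direction. This is a minimal NumPy illustration under stated assumptions, not the paper's implementation; the function names, the SGD-style client step, and the default values of `rho` and `lr` are all hypothetical.

```python
import numpy as np

def sam_perturbation_from_momentum(server_momentum, rho=0.05):
    """Hypothetical sketch: build a SAM-style perturbation from the
    server-aggregated momentum, so no extra local gradient pass is
    needed. `rho` is the perturbation radius."""
    norm = np.linalg.norm(server_momentum) + 1e-12  # avoid divide-by-zero
    return rho * server_momentum / norm

def client_step(weights, grad_fn, server_momentum, lr=0.1, rho=0.05):
    """One illustrative client update: perturb toward the global
    momentum direction, then take a single gradient step there."""
    eps = sam_perturbation_from_momentum(server_momentum, rho)
    g = grad_fn(weights + eps)  # the only backward pass per step
    return weights - lr * g
```

On a toy quadratic loss (gradient `2*w`), repeated `client_step` calls shrink the weight norm, mimicking descent at the momentum-perturbed point.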
Problem

Research questions and friction points this paper is trying to address.

Addresses non-IID data distribution challenges in federated learning optimization
Resolves momentum oscillation and curvature misalignment in federated training
Improves convergence speed and generalization under communication constraints
Innovation

Methods, ideas, or system contributions that make the work stand out.

Weighted momentum aligns SAM directions globally
Cosine-similarity rule enables two-phase training schedule
Single-backprop SAM approximation maintains computational efficiency
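The cosine-similarity rule named above can be illustrated with a small sketch: when the local gradient agrees with the aggregated momentum, the update leans on momentum (early phase); when they disagree, it leans on the SAM direction (late phase). The mapping of cosine similarity to a weight in `[0, 1]` and the linear mix are illustrative assumptions, not the paper's exact rule.

```python
import numpy as np

def coupling_weight(momentum, local_grad):
    """Hypothetical adaptive weight in [0, 1] from cosine similarity:
    near 1 when momentum and local gradient agree, near 0 when they
    point in opposite directions."""
    denom = np.linalg.norm(momentum) * np.linalg.norm(local_grad) + 1e-12
    cos = np.dot(momentum, local_grad) / denom
    return 0.5 * (1.0 + cos)  # map [-1, 1] -> [0, 1]

def combined_direction(momentum, sam_grad, local_grad):
    """Blend momentum and the SAM gradient by the adaptive weight."""
    alpha = coupling_weight(momentum, local_grad)
    return alpha * momentum + (1.0 - alpha) * sam_grad
```

Over training, drift between local gradients and global momentum grows under non-IID data, so this weight naturally decays, which matches the early-momentum, late-SAM two-phase schedule described in the abstract.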