🤖 AI Summary
When AdamW is applied in federated learning, data heterogeneity causes three key problems: high variance in the second-moment estimate, client drift, and slowed convergence due to reinitializing moment estimates each round. To address them, this paper proposes FedAdamW, the first optimizer specifically designed for efficient federated training and fine-tuning of large language and vision models. Its core innovations are: (1) a local correction mechanism to mitigate client drift; (2) decoupled weight decay to keep parameter updates globally consistent; (3) mean aggregation of second-moment estimates to significantly reduce estimation variance; and (4) the first provable linear-speedup convergence guarantee without a data-homogeneity assumption, complemented by a PAC-Bayes generalization bound. Extensive experiments on language and vision Transformers demonstrate that FedAdamW substantially reduces communication rounds and improves test accuracy. The implementation is publicly available.
📝 Abstract
AdamW has become one of the most effective optimizers for training large-scale models, and we observe that it is also effective in federated learning (FL). However, directly applying AdamW in FL settings poses significant challenges: (1) due to data heterogeneity, AdamW often yields high variance in the second-moment estimate $\boldsymbol{v}$; (2) local overfitting of AdamW may cause client drift; and (3) reinitializing the moment estimates ($\boldsymbol{v}$, $\boldsymbol{m}$) at each round slows down convergence. To address these challenges, we propose the first **Fed**erated **AdamW** algorithm, called `FedAdamW`, for training and fine-tuning various large models. `FedAdamW` aligns local updates with the global update using both a **local correction mechanism** and decoupled weight decay to mitigate local overfitting. `FedAdamW` efficiently aggregates the **mean** of the second-moment estimates to reduce their variance and uses it to reinitialize them. Theoretically, we prove that `FedAdamW` achieves a linear-speedup convergence rate of $\mathcal{O}(\sqrt{(L \Delta \sigma_l^2)/(S K R \epsilon^2)} + (L \Delta)/R)$ without a **heterogeneity assumption**, where $S$ is the number of participating clients per round, $K$ is the number of local iterations, and $R$ is the total number of communication rounds. We also employ a PAC-Bayesian generalization analysis to explain the effectiveness of decoupled weight decay in local training. Empirically, we validate the effectiveness of `FedAdamW` on language and vision Transformer models. Compared to several baselines, `FedAdamW` significantly reduces communication rounds and improves test accuracy. The code is available at https://github.com/junkangLiu0/FedAdamW.
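To make the round structure concrete, here is a minimal sketch of the pieces the abstract names: local AdamW steps with decoupled weight decay, a warm start of $\boldsymbol{v}$ from the server-aggregated mean, and mean aggregation of parameters and second moments at the server. The exact form of the local correction mechanism is not given in the abstract, so the `correction` term below is a placeholder assumption, and bias correction is omitted for brevity; see the repository for the authors' actual implementation.

```python
import numpy as np

def local_adamw_steps(x, v_init, correction, grads, lr=1e-3, beta1=0.9,
                      beta2=0.999, eps=1e-8, weight_decay=0.01):
    """K local AdamW steps on one client (sketch, bias correction omitted).

    `correction` is a stand-in for the paper's local correction mechanism;
    its exact form is an assumption here, not taken from the abstract.
    """
    m = np.zeros_like(x)   # first moment restarts each round
    v = v_init.copy()      # warm-started from the aggregated mean of v
    for g in grads:        # one stochastic gradient per local iteration
        g = g + correction               # drift-mitigating term (hypothetical form)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        # Decoupled weight decay: applied to x directly, not folded into g.
        x = x - lr * (m / (np.sqrt(v) + eps) + weight_decay * x)
    return x, v

def server_round(x_global, v_mean, client_grads, correction=0.0):
    """One communication round: each client trains locally, then the server
    averages parameters and the second-moment estimates (sketch)."""
    xs, vs = [], []
    for grads in client_grads:
        x_i, v_i = local_adamw_steps(x_global.copy(), v_mean, correction, grads)
        xs.append(x_i)
        vs.append(v_i)
    # Mean aggregation of v reduces its variance and becomes next round's warm start.
    return np.mean(xs, axis=0), np.mean(vs, axis=0)
```

Aggregating only the mean of $\boldsymbol{v}$ (a scalar summary in the paper; a full average here for simplicity) is what lets each round warm-start the second moment instead of reinitializing it to zero, which the abstract identifies as a cause of slow convergence.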