🤖 AI Summary
To address the poor performance of federated diffusion-model training on non-IID data, this paper proposes FedDDPM. In this framework, the server uses locally trained diffusion models uploaded by clients to generate auxiliary data that approximates the global data distribution, then lightly fine-tunes the aggregated global model on that data after each FedAvg round. The paper further introduces FedDDPM+, which adds dynamic heterogeneity detection and a single-step adaptive correction mechanism. The work pioneers the use of the generative capability of diffusion models for distribution calibration in federated settings, proposing a post-aggregation optimization paradigm driven by auxiliary data, together with a lightweight correction algorithm that admits theoretical analysis. Experiments on MNIST, CIFAR-10, and CIFAR-100 show faster convergence, stronger generalization, and significantly lower communication and computational overhead than existing federated diffusion learning approaches.
📝 Abstract
Diffusion models are powerful generative models that can produce highly realistic samples for a variety of tasks. Typically, these models are trained on centralized, independent and identically distributed (IID) data. In practice, however, data is often distributed across multiple clients and is frequently non-IID. Federated Learning (FL) can leverage this distributed data to train diffusion models, but existing FL methods perform unsatisfactorily in non-IID scenarios. To address this, we propose FedDDPM (Federated Learning with Denoising Diffusion Probabilistic Models), which leverages the data-generation capability of diffusion models to facilitate model training. In particular, before FL training begins, each client uploads a well-trained local diffusion model, and the server uses these models to generate auxiliary data that approximately represents the global data distribution. After each round of model aggregation, the server further optimizes the global model on the auxiliary dataset to alleviate the impact of heterogeneous data on model performance. We provide a rigorous convergence analysis of FedDDPM and propose an enhanced algorithm, FedDDPM+, to reduce training overhead. FedDDPM+ detects instances of slow model learning and performs a one-shot correction using the auxiliary dataset. Experimental results show that the proposed algorithms outperform state-of-the-art FL algorithms on the MNIST, CIFAR-10, and CIFAR-100 datasets.
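The server-side loop described in the abstract (aggregate, then refine on auxiliary data) can be illustrated with a minimal sketch. This is not the authors' implementation: a toy linear model with a mean-squared-error loss stands in for the diffusion model and its denoising loss, the auxiliary data is passed in directly rather than sampled from uploaded client models, and the names `fedavg`, `server_finetune`, and `learning_is_slow`, along with the `window` and `tol` parameters, are hypothetical. The slow-learning trigger in particular is an assumed placeholder, since the abstract does not specify how FedDDPM+ detects slow learning.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Sample = Tuple[Vector, float]  # (input features, target); toy stand-in for auxiliary data


def fedavg(client_weights: Sequence[Vector], client_sizes: Sequence[int]) -> Vector:
    """Standard FedAvg: average client parameters, weighted by local dataset size."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(dim)
    ]


def mse_loss(w: Vector, data: Sequence[Sample]) -> float:
    """Toy loss (assumption): mean squared error of a linear model w.x."""
    return sum(
        (sum(wi * xi for wi, xi in zip(w, x)) - y) ** 2 for x, y in data
    ) / len(data)


def mse_grad(w: Vector, data: Sequence[Sample]) -> Vector:
    """Gradient of mse_loss with respect to w."""
    grad = [0.0] * len(w)
    for x, y in data:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            grad[i] += 2.0 * err * xi / len(data)
    return grad


def server_finetune(w: Vector, aux_data: Sequence[Sample],
                    lr: float = 0.1, steps: int = 5) -> Vector:
    """Post-aggregation refinement: a few gradient steps on the auxiliary dataset."""
    for _ in range(steps):
        g = mse_grad(w, aux_data)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


def learning_is_slow(loss_history: Sequence[float],
                     window: int = 3, tol: float = 1e-3) -> bool:
    """Hypothetical trigger: loss improved by less than tol over the last `window` rounds."""
    if len(loss_history) <= window:
        return False
    return loss_history[-window - 1] - loss_history[-1] < tol


# One illustrative round: aggregate two clients, then refine on auxiliary data.
aux = [([1.0, 0.0], 2.0), ([0.0, 1.0], -1.0)]
global_w = fedavg([[0.0, 0.0], [0.4, -0.4]], client_sizes=[10, 30])
refined_w = server_finetune(global_w, aux)
```

In a FedDDPM-style loop, `server_finetune` would run after every aggregation round. A FedDDPM+-style variant would instead call it once, and only when a check such as `learning_is_slow` fires, which is one way to read the abstract's claim of reduced training overhead.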