🤖 AI Summary
This work investigates how pretrained initialization mitigates the performance degradation caused by data heterogeneity in federated learning (FL). Focusing on two-layer CNNs trained with FedAvg, the authors introduce the concepts of filter “alignment” and “misalignment” and prove that only misaligned filters are adversely affected by data heterogeneity, while pretrained initialization significantly reduces their number. Through analysis of gradient-descent convergence and generalization error bounds, they establish that pretrained initialization provably lowers the upper bound on the test error. Experiments on synthetic and real-world FL benchmarks confirm that pretrained initialization both accelerates convergence and improves final model accuracy. Key contributions: (i) the first filter-level modeling of heterogeneity effects in FL; (ii) a theoretical explanation of how pretrained initialization improves generalization; and (iii) empirical validation of this theory.
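To make the alignment idea concrete, here is a minimal numerical sketch. The alignment criterion used below (the sign of the inner product between a filter and a ground-truth signal direction) is an illustrative proxy, not the paper's exact definition, and the "pretrained" filters are simulated by nudging random filters toward the signal:

```python
import numpy as np

rng = np.random.default_rng(1)
d, m = 16, 64                      # filter dimension, number of filters
signal = rng.normal(size=d)
signal /= np.linalg.norm(signal)   # ground-truth feature direction

def misaligned_count(filters):
    """Count filters whose inner product with the signal is negative --
    an illustrative proxy for the paper's misalignment notion."""
    return int(np.sum(filters @ signal < 0))

# random init: filters are isotropic, so roughly half point away from the signal
random_filters = rng.normal(size=(m, d))

# "pretrained" init (simulated): the same filters nudged toward the signal
pretrained_filters = random_filters + 1.5 * signal

print(misaligned_count(random_filters))      # roughly m/2
print(misaligned_count(pretrained_filters))  # far fewer
```

Under this toy criterion, the pretrained initialization leaves far fewer misaligned filters, which is the mechanism the summary attributes the improved test error to.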
📝 Abstract
Initializing with a pre-trained model when learning downstream tasks is becoming standard practice in machine learning. Several recent works explore the benefits of pre-trained initialization in the federated learning (FL) setting, where downstream training is performed at edge clients with heterogeneous data distributions. These works show that starting from a pre-trained model can substantially reduce the adverse impact of data heterogeneity on the test performance of a model trained in a federated setting, with no changes to the standard FedAvg training algorithm. In this work, we provide a deeper theoretical understanding of this phenomenon. To do so, we study the class of two-layer convolutional neural networks (CNNs) and provide bounds on the training error convergence and test error of such a network trained with FedAvg. We introduce the notion of aligned and misaligned filters at initialization and show that data heterogeneity only affects learning on misaligned filters. Starting from a pre-trained model typically results in fewer misaligned filters at initialization, thus producing a lower test error even when the model is trained in a federated setting with data heterogeneity. Experiments in synthetic settings and practical FL training on CNNs verify our theoretical findings.
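For readers unfamiliar with FedAvg, the training loop the abstract refers to can be sketched as follows. This is a minimal illustration, not the paper's setup: the local model is a linear least-squares predictor rather than a two-layer CNN, and the clients, data shifts, and step sizes are made up to simulate heterogeneous data:

```python
import numpy as np

def local_sgd(weights, X, y, lr=0.05, steps=5):
    """A few local full-batch gradient steps on one client's data
    (illustrative least-squares objective, not the paper's CNN)."""
    w = weights.copy()
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def fedavg_round(global_w, client_datasets):
    """One FedAvg communication round: each client trains locally from the
    current global model, then the server averages the resulting models,
    weighted by local dataset size."""
    updates, sizes = [], []
    for X, y in client_datasets:
        updates.append(local_sgd(global_w, X, y))
        sizes.append(len(y))
    return np.average(np.stack(updates), axis=0, weights=np.array(sizes, float))

# toy heterogeneous clients: each sees inputs drawn around a different mean
rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0, 0.5])
clients = []
for shift in (0.0, 2.0, -2.0):
    X = rng.normal(shift, 1.0, size=(50, 3))
    clients.append((X, X @ w_true))

w = np.zeros(3)
for _ in range(20):
    w = fedavg_round(w, clients)
print(np.round(w, 2))  # approaches w_true over the rounds
```

Note that only model weights travel between clients and server; the heterogeneity enters through each client's local data distribution, which is the regime the paper's filter-level analysis targets.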