🤖 AI Summary
Federated learning (FL) suffers significant performance degradation under label shift, i.e., when clients' label distributions diverge from that of the target domain. To address this, the paper proposes FedPALS, the first server-side aggregation mechanism that explicitly incorporates a prior on the target-domain label distribution. By choosing client aggregation weights so that the implied mixture of client label distributions aligns with the target, FedPALS yields unbiased model updates under stochastic gradient descent, relaxing FL's conventional strong assumption of identical label distributions across clients. The method requires no additional client-side computation or communication overhead and preserves privacy. Extensive experiments on image classification tasks demonstrate that FedPALS consistently outperforms baselines such as FedAvg, particularly in challenging settings with sparse client participation and severe label distribution mismatch, effectively mitigating the performance deterioration induced by label shift.
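The aggregation idea above can be sketched in a few lines: the server solves for simplex weights whose mixture of client label distributions is closest (in squared error) to the target distribution, then averages client updates with those weights. This is a minimal illustrative sketch, not the paper's reference implementation; the function names and the projected-gradient solver are assumptions made here for clarity, and the paper's actual objective may include additional terms.

```python
import numpy as np

def project_simplex(v):
    # Euclidean projection of v onto the probability simplex
    # (sort-based algorithm; O(n log n)).
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u + (1.0 - css) / (np.arange(len(v)) + 1) > 0)[0][-1]
    theta = (1.0 - css[rho]) / (rho + 1)
    return np.maximum(v + theta, 0.0)

def fedpals_weights(client_label_dists, target_dist, steps=2000, lr=0.1):
    """Find simplex weights alpha minimizing || P @ alpha - q ||^2,
    where column i of P is client i's label distribution and q is
    the target label distribution. Solved by projected gradient descent."""
    P = np.asarray(client_label_dists, dtype=float).T  # (classes, clients)
    q = np.asarray(target_dist, dtype=float)
    alpha = np.full(P.shape[1], 1.0 / P.shape[1])      # start uniform
    for _ in range(steps):
        grad = 2.0 * P.T @ (P @ alpha - q)
        alpha = project_simplex(alpha - lr * grad)
    return alpha

def aggregate(client_updates, alpha):
    # Weighted average of client model updates; each update is a
    # list of parameter arrays, one entry per model tensor.
    return [sum(a * u[i] for a, u in zip(alpha, client_updates))
            for i in range(len(client_updates[0]))]
```

For example, with two clients whose label distributions are (0.9, 0.1) and (0.2, 0.8) and a uniform target, the solver assigns roughly 0.43 / 0.57 weight, so the weighted mixture matches the target rather than the (skewed) pooled client data.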
📝 Abstract
Federated learning enables multiple actors to collaboratively train models without sharing private data, unlocking the potential to scale machine learning to diverse applications. Existing algorithms for this task are well-justified when clients and the intended target domain share the same distribution of features and labels, but this assumption is often violated in real-world scenarios. One common violation is label shift, where the label distributions differ across clients or between clients and the target domain, which can significantly degrade model performance. To address this problem, we propose FedPALS, a novel model aggregation scheme that adapts to label shifts by leveraging knowledge of the target label distribution at the central server. Our approach ensures unbiased updates under stochastic gradient descent, yielding robust generalization across clients with diverse, label-shifted data. Extensive experiments on image classification demonstrate that FedPALS consistently outperforms standard baselines by aligning model aggregation with the target domain. Our findings reveal that conventional federated learning methods suffer severely in cases of extreme client sparsity, highlighting the critical need for target-aware aggregation. FedPALS offers a principled and practical solution to mitigate label distribution mismatch, ensuring that models trained in federated settings generalize effectively to label-shifted target domains.