🤖 AI Summary
In heterogeneous federated reinforcement learning, poor global policy generalization, severe distributional shift, and policy conflicts pose significant challenges. To address these issues, this paper proposes FedWB—a novel federated RL algorithm that enables clients to independently train DQNs in heterogeneous environments (e.g., CartPole with varying pole lengths) and introduces the Wasserstein barycenter for model aggregation. Unlike conventional gradient- or parameter-averaging methods (e.g., FedAvg), FedWB computes a geometric center in parameter space, enabling data-free, gradient-free global model fusion while preserving structural consistency across heterogeneous local policies. This approach mitigates performance degradation inherent in naive averaging under heterogeneity. Experiments on multi-pole-length CartPole demonstrate that the globally aggregated DQN via FedWB achieves over 95% stable control success across all heterogeneous settings—substantially outperforming FedAvg, FedProx, and other baselines in both robustness and generalization.
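The barycentric fusion idea can be illustrated in one dimension: under the 2-Wasserstein distance, the barycenter of equally sized 1D empirical distributions is obtained by averaging their sorted samples (quantile-wise means). This is only a minimal sketch of that building block, not the paper's full parameter-space procedure; the function name and the uniform default weights are assumptions for illustration.

```python
# Minimal sketch: the W2 barycenter of equally sized 1D empirical
# distributions reduces to a quantile-wise (sorted-sample) average.
# FedWB's actual aggregation operates on full network parameters;
# this only illustrates the one-dimensional building block.

def wasserstein_barycenter_1d(samples_per_client, weights=None):
    """Quantile-averaged W2 barycenter of equally sized 1D samples."""
    k = len(samples_per_client)
    if weights is None:
        weights = [1.0 / k] * k          # uniform barycenter weights
    sorted_samples = [sorted(s) for s in samples_per_client]
    n = len(sorted_samples[0])
    # i-th barycenter point = weighted mean of the i-th order statistics
    return [
        sum(w * s[i] for w, s in zip(weights, sorted_samples))
        for i in range(n)
    ]

# Example: barycenter of {0, 1, 2} and {2, 3, 4} with equal weights
print(wasserstein_barycenter_1d([[0, 1, 2], [4, 3, 2]]))
# → [1.0, 2.0, 3.0]
```

Unlike a naive coordinate-wise average of raw weight vectors, this quantile matching aligns distributions before averaging, which is the intuition behind using a geometric (barycentric) center for model fusion.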
📝 Abstract
In this paper, we first propose a novel algorithm for model fusion that leverages Wasserstein barycenters in training a global Deep Neural Network (DNN) in a distributed architecture. To this end, we divide the dataset into equal parts that are fed to "agents" who have identical deep neural networks and train only over the dataset fed to them (known as the local dataset). After some training iterations, we perform an aggregation step where we combine the weight parameters of all neural networks using Wasserstein barycenters. These steps form the proposed algorithm, referred to as FedWB. Moreover, we leverage the processes created in the first part of the paper to develop an algorithm to tackle Heterogeneous Federated Reinforcement Learning (HFRL). Our test experiment is the CartPole toy problem, where we vary the lengths of the poles to create heterogeneous environments. We train a Deep Q-Network (DQN) in each environment to learn to control each cart, while occasionally performing a global aggregation step to generalize the local models; the end outcome is a global DQN that functions across all environments.
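The training procedure the abstract describes — independent local training punctuated by periodic global aggregation and broadcast — can be sketched as a loop skeleton. Everything concrete here is an assumption for illustration: the toy `local_update` stands in for per-environment DQN training, and the coordinate-wise mean in `aggregate` is a placeholder where FedWB would instead compute a Wasserstein barycenter of the local parameters.

```python
# Hypothetical skeleton of the FedWB-style training loop: each client
# trains locally for several steps, then a global aggregation step
# fuses the local parameters and broadcasts the result back.
# local_update is a stand-in for local DQN training in one
# heterogeneous environment; aggregate() uses a plain coordinate-wise
# mean as a placeholder for the Wasserstein-barycenter fusion.

import random

def local_update(params, seed):
    """Stand-in for one local training step (toy random perturbation)."""
    rng = random.Random(seed)
    return [p + rng.uniform(-0.1, 0.1) for p in params]

def aggregate(client_params):
    """Placeholder fusion step: coordinate-wise mean of client params."""
    k = len(client_params)
    return [sum(ps) / k for ps in zip(*client_params)]

def federated_train(num_clients=3, rounds=5, local_steps=4, dim=8):
    global_params = [0.0] * dim              # shared initial model
    for r in range(rounds):
        local_params = []
        for c in range(num_clients):         # clients train independently
            p = list(global_params)
            for s in range(local_steps):
                p = local_update(p, seed=r * 1000 + c * 10 + s)
            local_params.append(p)
        # periodic global aggregation, then (implicit) broadcast
        global_params = aggregate(local_params)
    return global_params

print(federated_train())
```

Swapping the placeholder mean in `aggregate` for a barycentric fusion of the local parameter distributions is what would distinguish this skeleton from FedAvg-style averaging.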