🤖 AI Summary
This work addresses bias in context estimation for reinforcement learning, caused by the distributional mismatch between training and deployment in non-stationary environments. Within the framework of contextual Markov decision processes (cMDPs), the authors propose randomly deleting a fraction of the training buffer after each round of data collection. This implicitly attenuates the influence of outdated data and improves the generalization of the context estimator, which in turn makes the learned policy more adaptable, all without requiring explicit identification of obsolete samples. Theoretical analysis shows that, under mild conditions, removing a uniformly random training point reduces expected test loss under distribution shift. Empirically, the method reduces the robustness gap by 30% for multilayer perceptrons (MLPs) and by 6% on average for recurrent networks. Notably, a narrow MLP with five times fewer parameters outperforms a wider counterpart trained without the proposed mechanism.
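To make the "implicit attenuation" concrete: if a fraction p of the buffer is deleted uniformly at random after each round, a sample collected in round k is still present when the estimator is trained in round t with probability (1 - p)^(t - k), so older rounds are down-weighted geometrically without ever being labelled as stale. The following is a minimal sketch under assumed details the summary does not specify: a fixed `delete_frac`, deletion applied right after each round's training step, and entries that only record the round that produced them. The function names are hypothetical.

```python
import random
from collections import Counter


def simulate_buffer(num_rounds, samples_per_round, delete_frac, seed=0):
    """Count how many samples from each round are in the buffer when the
    estimator is trained in the final round.

    Hypothetical schedule: at the start of every round after the first, delete
    round(delete_frac * len(buffer)) entries uniformly at random (the deletion
    that follows the previous round's training), then append the new round's
    samples. Each entry stores only the index of the round that produced it.
    """
    rng = random.Random(seed)
    buffer = []
    for r in range(num_rounds):
        if r > 0:
            n_delete = round(delete_frac * len(buffer))
            # Sampling without replacement keeps a uniformly random subset.
            buffer = rng.sample(buffer, len(buffer) - n_delete)
        buffer.extend([r] * samples_per_round)
        # The context estimator would be trained on `buffer` at this point.
    return Counter(buffer)


def expected_counts(num_rounds, samples_per_round, delete_frac):
    """Geometric decay: data from round k has gone through
    (num_rounds - 1 - k) deletions, each keeping a (1 - delete_frac) share."""
    return [samples_per_round * (1 - delete_frac) ** (num_rounds - 1 - k)
            for k in range(num_rounds)]


if __name__ == "__main__":
    counts = simulate_buffer(num_rounds=5, samples_per_round=1000, delete_frac=0.5)
    print([counts[k] for k in range(5)])
    print(expected_counts(5, 1000, 0.5))
```

With `delete_frac=0.5` the expected counts are `[62.5, 125.0, 250.0, 500.0, 1000.0]`: the newest round is intact, and every earlier round keeps a random, and therefore still diverse, subset whose size halves with age.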
📝 Abstract
Deploying reinforcement learning policies in the real world requires adapting to time-varying environments. We study this problem in the contextual Markov Decision Process (cMDP) framework, where a family of environments is indexed by a low-dimensional context that is unknown at test time. The standard approach decomposes the problem: train a so-called "universal policy" that assumes knowledge of the true context, then pair it with a context estimator that approximates the context from the observed trajectory. We identify a simple, counterintuitive trick that substantially improves the estimator: randomly delete a fraction of the training buffer after each round. This works because data is collected across multiple rounds using progressively better policies, so older trajectories come from a different distribution than the one the estimator will face at deployment time. Random deletion imposes an implicit exponential decay on older data while preserving diversity, and it does not require any explicit identification of which samples are stale. This reduces the robustness gap by 30% for MLPs and by 6% on average for recurrent networks. Strikingly, it allows a narrow MLP with 5x fewer parameters to outperform a wide MLP trained without deletion. To understand when and why deletion helps, we analyze regularized empirical risk minimization with a mismatch between the training distribution and the deployment distribution; in this idealized setting, we prove that, under mild conditions, removing a single uniformly random training point decreases the expected test loss in expectation. For ridge regression we make this quantitative: deletion helps when the regularization coefficient is moderate and the signal-to-noise ratio (SNR) is sufficiently low, and, crucially, this SNR threshold gives a direct measure of how large the distribution mismatch between training and deployment must be for deletion to be beneficial.