🤖 AI Summary
This work addresses the performance degradation of policies in offline reinforcement learning caused by unobserved confounding—such as perceptual discrepancies in pixel-based observations—by introducing causal inference into the flow-matching framework for the first time. The authors propose a worst-case confounding-robust optimization objective that uses flow matching for policy modeling, employs a deep discriminator to evaluate policy divergence, and is embedded within a Q-learning architecture to mitigate confounding bias. Evaluated on 25 pixel-based tasks, the proposed method achieves a success rate 20% higher than state-of-the-art offline RL approaches, demonstrating substantially improved robustness to unobserved confounders.
📝 Abstract
Expressive policies based on flow-matching have recently been applied successfully in reinforcement learning (RL) due to their ability to model complex action distributions from offline data. These algorithms build on standard policy gradients, which assume that there is no unmeasured confounding in the data. However, this condition does not necessarily hold for pixel-based demonstrations when a mismatch exists between the demonstrator's and the learner's sensory capabilities, leading to implicit confounding biases in offline data. We address this challenge by investigating the problem of confounded observations in offline RL from a causal perspective. We develop a novel causal offline RL objective that optimizes policies' worst-case performance under the confounding biases that may arise. Based on this new objective, we introduce a practical implementation that learns expressive flow-matching policies from confounded demonstrations, employing a deep discriminator to assess the discrepancy between the target policy and the nominal behavioral policy. Experiments across 25 pixel-based tasks demonstrate that our proposed confounding-robust augmentation procedure achieves a success rate 120% that of confounding-unaware, state-of-the-art offline RL methods.
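The abstract does not spell out the form of the worst-case objective, but a common way to formalize confounding-robust evaluation is a marginal-sensitivity-model relaxation: each importance weight on an offline sample is only known up to a box interval (e.g. within a factor Γ of its nominal value), and the policy is scored by its minimum weighted return over all feasible weights. As a rough illustration only — this is a generic sensitivity-model sketch, not the paper's actual algorithm, and `worst_case_value`, `lo`, `hi`, and Γ are assumed names — the inner minimization of a box-constrained weighted average can be solved exactly by a threshold scan over sorted rewards:

```python
import numpy as np

def worst_case_value(rewards, lo, hi):
    """Worst case of sum(w*r)/sum(w) over per-sample weights w_i in [lo_i, hi_i].

    Illustrative sketch: under a sensitivity model, each nominal
    importance weight may be perturbed within [lo_i, hi_i] (e.g.
    [w_i/Gamma, w_i*Gamma]) by unobserved confounding.  The minimum
    of the weighted average is attained at a vertex of the box:
    samples with reward below some threshold take their upper weight,
    the rest their lower weight.  We scan all thresholds in
    sorted-reward order and keep the smallest value.
    """
    order = np.argsort(rewards)
    r, a, b = rewards[order], lo[order], hi[order]
    n = len(r)
    best = np.inf
    for k in range(n + 1):
        # high weight on the k lowest rewards, low weight on the rest
        w = np.concatenate([b[:k], a[k:]])
        best = min(best, float(w @ r / w.sum()))
    return best

# Example: nominal weights 1, confounding factor Gamma = 3
rewards = np.array([0.0, 1.0])
gamma = 3.0
print(worst_case_value(rewards, np.full(2, 1 / gamma) * 1.0,
                       np.full(2, gamma) * 1.0))
```

Maximizing this pessimistic value over policies yields a confounding-robust objective; the paper's method additionally parameterizes the policy with flow matching and estimates the policy/behavior discrepancy with a learned discriminator rather than fixed weights.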