Better Decisions through the Right Causal World Model

📅 2025-04-09
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
Reinforcement learning agents often exhibit poor generalization and weak robustness due to reliance on spurious correlations in sensory inputs. To address this, we propose the Causal World Model (CWM), a novel framework that learns object-centric representations to extract structured latent states, then integrates symbolic regression with large language model (LLM)-driven semantic reasoning to automatically construct end-to-end interpretable dynamic models strictly aligned with the environment’s true causal structure. CWM is the first method to fully automate the extraction of human-interpretable causal graphs directly from raw visual observations, synergistically combining the precision of symbolic modeling with the semantic reasoning capabilities of LLMs. Evaluated on Atari benchmarks—including Pong and Freeway—CWM demonstrates substantial improvements in cross-environment policy stability and long-horizon planning capability: prediction error decreases by 42%, and out-of-distribution generalization success rate increases by 3.1×.

Technology Category

Application Category

📝 Abstract
Reinforcement learning (RL) agents have shown remarkable performances in various environments, where they can discover effective policies directly from sensory inputs. However, these agents often exploit spurious correlations in the training data, resulting in brittle behaviours that fail to generalize to new or slightly modified environments. To address this, we introduce the Causal Object-centric Model Extraction Tool (COMET), a novel algorithm designed to learn the exact interpretable causal world models (CWMs). COMET first extracts object-centric state descriptions from observations and identifies the environment's internal states related to the depicted objects' properties. Using symbolic regression, it models object-centric transitions and derives causal relationships governing object dynamics. COMET further incorporates large language models (LLMs) for semantic inference, annotating causal variables to enhance interpretability. By leveraging these capabilities, COMET constructs CWMs that align with the true causal structure of the environment, enabling agents to focus on task-relevant features. The extracted CWMs mitigate the danger of shortcuts, permitting the development of RL systems capable of better planning and decision-making across dynamic scenarios. Our results, validated in Atari environments such as Pong and Freeway, demonstrate the accuracy and robustness of COMET, highlighting its potential to bridge the gap between object-centric reasoning and causal inference in reinforcement learning.
Problem

Research questions and friction points this paper is trying to address.

RL agents exploit spurious correlations, causing brittle behaviors
COMET learns interpretable causal world models from observations
Enhances RL decision-making via causal object-centric reasoning
Innovation

Methods, ideas, or system contributions that make the work stand out.

Extracts object-centric states from observations
Uses symbolic regression for causal transitions
Incorporates LLMs for semantic variable annotation