TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning

📅 2025-10-01
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses zero-shot generalization in offline, reward-free reinforcement learning. The authors propose TD-JEPA, a framework that models multi-step policy dynamics in latent space via temporal-difference (TD) learning, enabling representation learning without online interaction or task-specific rewards. TD-JEPA combines latent prediction, policy-conditioned multi-step forecasting, and explicit state and task encoders, while jointly optimizing a set of parameterized policies in latent space, which mitigates representation collapse and allows the predictor to recover successor features. By unifying TD learning with latent-space predictive modeling, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets from ExoRL and OGBench. Notably, it excels at zero-shot reward optimization from pixel inputs, generalizing to unseen tasks without fine-tuning.

📝 Abstract
Latent prediction, where agents learn by predicting their own latents, has emerged as a powerful paradigm for training general representations in machine learning. In reinforcement learning (RL), this approach has been explored to define auxiliary losses for a variety of settings, including reward-based and unsupervised RL, behavior cloning, and world modeling. While existing methods are typically limited to single-task learning, one-step prediction, or on-policy trajectory data, we show that temporal difference (TD) learning enables learning representations predictive of long-term latent dynamics across multiple policies from offline, reward-free transitions. Building on this, we introduce TD-JEPA, which brings TD-based latent-predictive representations to unsupervised RL. TD-JEPA trains explicit state and task encoders, a policy-conditioned multi-step predictor, and a set of parameterized policies directly in latent space. This enables zero-shot optimization of any reward function at test time. Theoretically, we show that an idealized variant of TD-JEPA avoids collapse with proper initialization, and learns encoders that capture a low-rank factorization of long-term policy dynamics, while the predictor recovers their successor features in latent space. Empirically, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets in ExoRL and OGBench, especially in the challenging setting of zero-shot RL from pixels.
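The core identity behind the abstract's successor-feature claim is the latent-space Bellman equation psi(s) = phi(s) + gamma * E[psi(s')], which TD learning can solve by bootstrapping. The toy sketch below is not the authors' code: it uses a hypothetical 3-state deterministic cycle, one-hot features in place of learned encoder outputs, and tabular TD(0) updates, then checks the estimate against the truncated Neumann series (I - gamma*P)^(-1) * Phi.

```python
# Toy tabular sketch (illustrative assumption, not TD-JEPA itself):
# TD learning of successor features psi(s) = phi(s) + gamma * psi(s'),
# the latent-space fixed point the paper's predictor is said to recover.

gamma = 0.9
next_state = {0: 1, 1: 2, 2: 0}          # deterministic policy-induced dynamics
phi = {0: [1.0, 0.0, 0.0],               # stand-in "latent" features (one-hot)
       1: [0.0, 1.0, 0.0],
       2: [0.0, 0.0, 1.0]}

# TD(0): move psi(s) toward the bootstrapped target phi(s) + gamma * psi(s')
psi = {s: [0.0, 0.0, 0.0] for s in phi}
alpha = 0.5
for _ in range(500):
    for s in phi:
        s2 = next_state[s]
        target = [phi[s][i] + gamma * psi[s2][i] for i in range(3)]
        psi[s] = [psi[s][i] + alpha * (target[i] - psi[s][i]) for i in range(3)]

# Ground truth via the truncated Neumann series sum_k gamma^k * phi(s_k)
psi_true = {s: [0.0, 0.0, 0.0] for s in phi}
for s in phi:
    cur, w = s, 1.0
    for _ in range(500):
        for i in range(3):
            psi_true[s][i] += w * phi[cur][i]
        cur, w = next_state[cur], w * gamma

err = max(abs(psi[s][i] - psi_true[s][i]) for s in phi for i in range(3))
print(round(err, 6))  # → 0.0 (TD estimate matches the closed form)
```

TD-JEPA replaces the table with a policy-conditioned neural predictor and the one-hot phi with a learned state encoder, but the bootstrapped target has the same shape; the paper's theoretical analysis concerns when this recursion avoids collapse under learned encoders.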
Problem

Research questions and friction points this paper is trying to address.

Learning latent-predictive representations from offline reward-free transitions
Enabling zero-shot optimization of any reward function at test time
Training policies directly in latent space for unsupervised reinforcement learning
Innovation

Methods, ideas, or system contributions that make the work stand out.

TD-based latent-predictive representations for long-term dynamics
Unsupervised training of state and task encoders in latent space
Zero-shot reward optimization via policy-conditioned multi-step prediction