JEDI: Joint Embedding Diffusion World Model for Online Model-Based Reinforcement Learning

📅 2026-05-13
📈 Citations: 0
Influential: 0

🤖 AI Summary
This work addresses the efficiency–performance trade-off in existing diffusion-based world models for online reinforcement learning: pixel-level diffusion incurs high computational cost, while latent diffusion approaches that rely on separately pretrained representations perform worse and are not trained end to end. To overcome these limitations, we propose JEDI, described as the first online, end-to-end latent diffusion world model. It combines a Joint Embedding Predictive Architecture (JEPA) with conditional latent diffusion denoising so that future latents are learned directly from the denoising loss, with no reconstruction objective and no pretraining. A theoretical analysis shows that conventional JEPA objectives induce a predictive information bottleneck and that conditional diffusion denoising admits a closely related predictive-compression decomposition. Empirically, JEDI is competitive on Atari100k and outperforms the baseline with separately trained latents where directly comparable; relative to pixel diffusion, it uses 43% less VRAM, samples from the world model over 3× faster, and trains 2.5× faster.
📝 Abstract
Diffusion world models have recently become competitive for online model-based reinforcement learning, but current approaches expose a tension: pixel diffusion is effective but computationally expensive, while the latest latent diffusion approach improves efficiency yet underperforms. The latter also relies on separately trained latents rather than the end-to-end world-model objectives that have driven much of modern MBRL progress. In particular, JEPA-style predictive representation learning has emerged as an especially promising direction for world modeling and MBRL. Concurrently, diffusion-style objectives have gained traction across multiple domains, with iterative refinement as a promising approach for multimodal and stochastic targets. Taken together, these trends motivate Joint Embedding DIffusion (JEDI), the first online end-to-end latent diffusion world model. JEDI learns its latent space directly from the diffusion denoising loss within a JEPA framework, using denoising to learn and predict future latents rather than relying on reconstruction and pretrained models. We provide a theoretical motivation showing that conventional JEPA objectives induce a predictive information bottleneck, and that conditional diffusion denoising admits a closely related predictive-compression decomposition. Empirically, JEDI is competitive on Atari100k and outperforms the baseline with separately trained latents where directly comparable. Relative to the pixel diffusion baseline, JEDI uses 43% less VRAM and achieves over 3× faster world-model sampling and 2.5× faster training. JEDI also exhibits a markedly different task-level performance profile from the pixel baseline, suggesting that end-to-end predictive latents change more than computational cost alone.
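The abstract's core mechanism, predicting the next latent by denoising it rather than by reconstructing pixels, can be sketched as a single loss computation. This is a minimal pure-Python sketch under stated assumptions, not JEDI's implementation: the linear-plus-tanh `encode`, the linear `denoise`, the stop-gradient/EMA target encoder, the clean-latent prediction target, and the EDM-style log-normal noise level are all hypothetical choices made for illustration.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a matrix (stored as a list of rows) by a vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    """Gaussian-initialised weights scaled by 1/sqrt(fan_in)."""
    scale = 1.0 / math.sqrt(cols)
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def encode(w_enc: Matrix, obs: Vector) -> Vector:
    """Hypothetical encoder: linear map followed by tanh, observation -> latent."""
    return [math.tanh(v) for v in matvec(w_enc, obs)]


def denoise(w_den: Matrix, z_noisy: Vector, sigma: float,
            z_ctx: Vector, action: Vector) -> Vector:
    """Hypothetical conditional denoiser: estimates the clean future latent from
    the noised target, conditioned on noise level, context latent and action."""
    features = z_noisy + z_ctx + action + [sigma]  # list concatenation
    return matvec(w_den, features)


def latent_denoising_loss(w_enc: Matrix, w_target: Matrix, w_den: Matrix,
                          obs_t: Vector, action_t: Vector, obs_next: Vector,
                          rng: random.Random) -> float:
    """One-sample denoising loss on the next latent; no pixel reconstruction term."""
    z_ctx = encode(w_enc, obs_t)            # online encoder (would receive gradients)
    z_target = encode(w_target, obs_next)   # target encoder (assumed stop-grad / EMA copy)
    sigma = math.exp(rng.gauss(-1.2, 1.2))  # log-normal noise level (EDM-style assumption)
    z_noisy = [z + sigma * rng.gauss(0.0, 1.0) for z in z_target]
    z_pred = denoise(w_den, z_noisy, sigma, z_ctx, action_t)
    return sum((p - t) ** 2 for p, t in zip(z_pred, z_target)) / len(z_target)


def ema_update(w_target: Matrix, w_online: Matrix, tau: float = 0.99) -> Matrix:
    """Move the target encoder weights slowly toward the online encoder."""
    return [[tau * t + (1.0 - tau) * o for t, o in zip(row_t, row_o)]
            for row_t, row_o in zip(w_target, w_online)]


if __name__ == "__main__":
    rng = random.Random(0)
    obs_dim, latent_dim, action_dim = 8, 4, 2
    w_enc = random_matrix(latent_dim, obs_dim, rng)
    w_target = [row[:] for row in w_enc]  # target starts as a copy of the online encoder
    w_den = random_matrix(latent_dim, 2 * latent_dim + action_dim + 1, rng)
    obs_t = [rng.gauss(0.0, 1.0) for _ in range(obs_dim)]
    obs_next = [rng.gauss(0.0, 1.0) for _ in range(obs_dim)]
    action_t = [1.0, 0.0]  # one-hot action
    loss = latent_denoising_loss(w_enc, w_target, w_den, obs_t, action_t, obs_next, rng)
    print(f"denoising loss: {loss:.4f}")
    w_target = ema_update(w_target, w_enc)
```

Because pure Python has no autodiff, the sketch only evaluates the loss. In an autodiff framework the same scalar would be backpropagated into both the denoiser and the online encoder weights, which is what "learning the latent space directly from the denoising loss" amounts to, as opposed to freezing a pretrained encoder. The EMA target is one common JEPA device against representation collapse; whether JEDI uses it is not stated in the abstract.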
Problem

Research questions and friction points this paper is trying to address.

diffusion world models
model-based reinforcement learning
latent representations
online learning
predictive modeling
Innovation

Methods, ideas, or system contributions that make the work stand out.

latent diffusion
JEPA
online MBRL
predictive representation learning
end-to-end world model