🤖 AI Summary
Problem: Joint-Embedding Predictive Architecture (JEPA) is increasingly used in visual representation learning and model-based reinforcement learning, yet its behavior remains poorly understood, in particular when its representations collapse and whether it keeps apart non-equivalent observations (those that differ in transition dynamics or auxiliary label).
Method: We study a simple, practical JEPA variant with an auxiliary regression head that is trained jointly, end to end, with the latent dynamics, so that the auxiliary task constrains which distinctions the encoder must preserve.
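Contribution/Results: Theoretically, we prove a No Unhealthy Representation Collapse theorem: in deterministic MDPs, if training drives both the latent-transition consistency loss and the auxiliary regression loss to zero, then non-equivalent observations must map to distinct latent representations. The auxiliary task therefore anchors which distinctions the representation preserves, which suggests a design principle for JEPA encoders. Empirically, controlled ablations in a counting environment corroborate the theory and show that training the JEPA model jointly with the auxiliary head yields a richer representation than training them separately.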
📝 Abstract
Joint-Embedding Predictive Architecture (JEPA) is increasingly used for visual representation learning and as a component in model-based RL, but its behavior remains poorly understood. We provide a theoretical characterization of a simple, practical JEPA variant that has an auxiliary regression head trained jointly with latent dynamics. We prove a No Unhealthy Representation Collapse theorem: in deterministic MDPs, if training drives both the latent-transition consistency loss and the auxiliary regression loss to zero, then any pair of non-equivalent observations, i.e., those that do not have the same transition dynamics or auxiliary label, must map to distinct latent representations. Thus, the auxiliary task anchors which distinctions the representation must preserve. Controlled ablations in a counting environment corroborate the theory and show that training the JEPA model jointly with the auxiliary head generates a richer representation than training them separately. Our work indicates a path to improve JEPA encoders: training them with an auxiliary function that, together with the transition dynamics, encodes the right equivalence relations.
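To make the two losses in the theorem concrete, here is a minimal sketch on a toy deterministic chain. It is not the paper's implementation or its counting environment: the chain, the choice of the count as the auxiliary label, and the names `joint_losses`, `encode`, `predict`, and `aux_head` are all assumptions made for illustration. It shows that a fully collapsed encoder can reach zero latent-transition loss, but cannot reach zero auxiliary regression loss when labels differ across states.

```python
import numpy as np

# Hypothetical toy MDP (an assumption, not the paper's environment):
# states 0..N-1, a single action that increments the count, saturating at N-1.
N = 5
states = np.arange(N)
next_states = np.minimum(states + 1, N - 1)
aux_labels = states.astype(float)  # assumed auxiliary target: the count itself


def joint_losses(encode, predict, aux_head):
    """Return (latent-transition consistency loss, auxiliary regression loss).

    encode:   observation -> scalar latent
    predict:  latent -> predicted next latent
    aux_head: latent -> predicted auxiliary label
    """
    z = np.array([encode(s) for s in states], dtype=float)
    z_next = np.array([encode(s) for s in next_states], dtype=float)
    z_pred = np.array([predict(zi) for zi in z], dtype=float)
    transition_loss = float(np.mean((z_pred - z_next) ** 2))

    aux_pred = np.array([aux_head(zi) for zi in z], dtype=float)
    aux_loss = float(np.mean((aux_pred - aux_labels) ** 2))
    return transition_loss, aux_loss


# Collapsed encoder: every observation maps to the same latent. The identity
# predictor makes the transition loss zero, but the auxiliary head sees one
# input only, so the best it can do is output a constant (here the label mean).
collapsed = joint_losses(lambda s: 0.0, lambda z: z, lambda z: 2.0)

# Injective encoder: the latent is the count, so both losses can reach zero.
healthy = joint_losses(
    lambda s: float(s),
    lambda z: min(z + 1.0, N - 1.0),
    lambda z: z,
)

print("collapsed (transition, aux):", collapsed)
print("healthy   (transition, aux):", healthy)
```

In this sketch every pair of states carries a different auxiliary label, so all states are non-equivalent, and only the encoder that separates them drives both losses to zero. This is consistent with the role the abstract assigns to the auxiliary task.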