Learn from your own latents and not from tokens: A sample-complexity theory

📅 2026-05-26
🤖 AI Summary
This work addresses the poor sample efficiency of generative models relative to biological learners. It analyzes latent prediction, in which a model learns by predicting its own latent representations rather than raw tokens. For compositional data generated by a probabilistic context-free grammar with a hidden tree of depth $L$, latent prediction recovers the tree with a number of samples that is constant in $L$ up to logarithmic factors, whereas supervised and token-level self-supervised methods need a number of samples exponential in $L$. The bound is confirmed with a hierarchical clustering algorithm, an end-to-end neural network, and a sample-complexity analysis of data2vec, which is shown to perform hierarchical latent prediction implicitly. The analysis suggests that explicitly stacking latent-prediction modules, as in H-JEPA, is largely redundant.
📝 Abstract
Generative models, from diffusion models to large language models, achieve remarkable performance but at a cost in training data orders of magnitude larger than what biological learners require. An alternative paradigm has emerged in which networks are trained to predict their own latent representations of related views or masked regions, as in data2vec and JEPA, an idea related to predictive-coding accounts of the cortex. Despite strong empirical results, the theoretical understanding of these methods remains limited. Central questions include: by how much does latent prediction actually improve data efficiency? Is there a benefit to stacking such methods into multi-scale hierarchies? We answer both questions using data drawn from a tractable probabilistic context-free grammar that captures the compositional structure of natural language and images. Such a grammar generates strings of visible tokens by recursively applying production rules along a tree of hidden symbols of depth $L$. For such data, supervised learning and token-level self-supervised learning (SSL) require a number of samples exponential in $L$ to recover the latent tree; we prove that latent prediction achieves this with a number of samples constant in $L$, up to logarithmic factors. We confirm this bound with (i) a hierarchical clustering algorithm, (ii) an end-to-end neural network whose predictor-clusterer modules predict their own latents at each level via gradient descent, and (iii) the first sample-complexity analysis of data2vec, which we show implicitly performs hierarchical latent prediction. This suggests that explicit stacking such as H-JEPA is largely redundant.
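To make the data model concrete, the following is a minimal sketch of how a depth-$L$ grammar of this kind could generate a string of visible tokens. It is an illustration under stated assumptions, not the paper's construction: the function names `make_grammar` and `sample_string`, the parameters `vocab_size`, `num_rules`, and `branching`, and the choice of uniformly random rules are all hypothetical. In particular, this sketch does not enforce that two hidden symbols never share a production, which a carefully specified grammar would likely require.

```python
import random


def make_grammar(vocab_size, num_rules, branching, depth, seed=0):
    """Build one rule table per level of the tree (hypothetical layout).

    Each table maps a hidden symbol to `num_rules` candidate productions,
    where a production is a tuple of `branching` child symbols.
    Assumption: rules are drawn uniformly at random and may collide.
    """
    rng = random.Random(seed)
    grammar = []
    for _ in range(depth):
        level = {
            symbol: [
                tuple(rng.randrange(vocab_size) for _ in range(branching))
                for _ in range(num_rules)
            ]
            for symbol in range(vocab_size)
        }
        grammar.append(level)
    return grammar


def sample_string(grammar, root, rng):
    """Expand `root` one level at a time and return the leaf tokens."""
    symbols = [root]
    for level in grammar:
        expanded = []
        for symbol in symbols:
            # Pick one of the symbol's productions uniformly at random.
            expanded.extend(rng.choice(level[symbol]))
        symbols = expanded
    return symbols


if __name__ == "__main__":
    L = 3  # depth of the hidden tree
    grammar = make_grammar(vocab_size=4, num_rules=2, branching=2, depth=L)
    tokens = sample_string(grammar, root=0, rng=random.Random(1))
    print(len(tokens), tokens)  # the string has branching**L = 8 tokens
```

A learner only sees the returned leaf tokens; the intermediate `symbols` lists are the hidden tree levels that the abstract's recovery results refer to.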
Problem

Research questions and friction points this paper is trying to address.

sample complexity
latent prediction
self-supervised learning
hierarchical representation
data efficiency
Innovation

Methods, ideas, or system contributions that make the work stand out.

latent prediction
sample complexity
hierarchical representation
self-supervised learning
data2vec