Learning Long-Context Diffusion Policies via Past-Token Prediction

📅 2025-05-14
📈 Citations: 0
Influential: 0
🤖 AI Summary
In long-horizon robotic imitation learning, diffusion policies suffer from loss of historical information, rising memory costs, and degraded performance as context length grows. To address this, we propose Past-Token Prediction (PTP), an auxiliary task in which the policy predicts past action tokens alongside future ones, combined with a multi-stage training recipe and a test-time self-verification mechanism to better model long-range temporal dependencies. Rather than training everything jointly over long contexts, the method pre-trains the vision encoder on short contexts and then fine-tunes the diffusion policy head on cached long-context embeddings. Evaluated on four real-world and six simulated tasks, the approach improves long-context policy success rates by 3× and accelerates training by more than 10×.

📝 Abstract
Reasoning over long sequences of observations and actions is essential for many robotic tasks. Yet, learning effective long-context policies from demonstrations remains challenging. As context length increases, training becomes increasingly expensive due to rising memory demands, and policy performance often degrades as a result of spurious correlations. Recent methods typically sidestep these issues by truncating context length, discarding historical information that may be critical for subsequent decisions. In this paper, we propose an alternative approach that explicitly regularizes the retention of past information. We first revisit the copycat problem in imitation learning and identify an opposite challenge in recent diffusion policies: rather than over-relying on prior actions, they often fail to capture essential dependencies between past and future actions. To address this, we introduce Past-Token Prediction (PTP), an auxiliary task in which the policy learns to predict past action tokens alongside future ones. This regularization significantly improves temporal modeling in the policy head, with minimal reliance on visual representations. Building on this observation, we further introduce a multistage training strategy: pre-train the visual encoder with short contexts, and fine-tune the policy head using cached long-context embeddings. This strategy preserves the benefits of PTP while greatly reducing memory and computational overhead. Finally, we extend PTP into a self-verification mechanism at test time, enabling the policy to score and select candidates consistent with past actions during inference. Experiments across four real-world and six simulated tasks demonstrate that our proposed method improves the performance of long-context diffusion policies by 3x and accelerates policy training by more than 10x.
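To make the PTP objective concrete, here is a minimal sketch of the auxiliary loss under stated assumptions: the policy head is assumed to output a single token window covering both past and future actions, and the diffusion denoising target is replaced by a plain MSE on a toy scalar sequence. All names (`ptp_loss`, `ptp_weight`, `pred_window`) are illustrative, not from the authors' code.

```python
# Minimal sketch of the Past-Token Prediction (PTP) auxiliary loss.
# Assumption: the policy predicts a window of tokens where the first
# len(past_actions) entries correspond to past actions and the rest to
# future actions. The real method applies this inside a diffusion
# policy; here we use plain MSE on scalar actions for clarity.

def mse(pred, target):
    """Mean squared error over two equal-length sequences of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def ptp_loss(pred_window, past_actions, future_actions, ptp_weight=0.5):
    """Supervise future tokens as usual, plus past tokens as an
    auxiliary target so the policy must retain action history."""
    n_past = len(past_actions)
    past_pred = pred_window[:n_past]
    future_pred = pred_window[n_past:]
    future_term = mse(future_pred, future_actions)  # standard imitation target
    past_term = mse(past_pred, past_actions)        # PTP auxiliary target
    return future_term + ptp_weight * past_term

# Example: one past action, two future actions.
loss = ptp_loss([0.9, 1.1, 2.2], past_actions=[1.0], future_actions=[1.0, 2.0])
```

The past-token term acts purely as a regularizer: it is dropped at deployment (except for self-verification), but during training it forces the policy head, rather than the visual features alone, to carry information about earlier actions.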
Problem

Research questions and friction points this paper is trying to address.

Learning effective long-context policies from demonstrations
Addressing spurious correlations in long-context policy training
Retaining critical historical information for decision-making
Innovation

Methods, ideas, or system contributions that make the work stand out.

Past-Token Prediction for temporal modeling
Multistage training with cached embeddings
Self-verification mechanism during inference
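The self-verification idea can be sketched as follows. This is a hypothetical interface, not the authors' API: each sampled candidate is assumed to carry both its reconstructed past tokens and its proposed future actions, and candidates are ranked by how well their past predictions match the actions that were actually executed.

```python
# Sketch of PTP-based self-verification at inference time.
# Assumption: the diffusion policy samples several candidate windows,
# each containing predicted past tokens and proposed future actions.
# Names ("past", "future", self_verify) are illustrative only.

def past_consistency_score(candidate_past, executed_past):
    """Negative squared error between a candidate's predicted past
    tokens and the actions actually executed; higher is better."""
    return -sum((p - e) ** 2 for p, e in zip(candidate_past, executed_past))

def self_verify(candidates, executed_past):
    """Execute the candidate whose predicted past tokens are most
    consistent with the real action history."""
    best = max(
        candidates,
        key=lambda c: past_consistency_score(c["past"], executed_past),
    )
    return best["future"]

# Example: the second candidate reconstructs the executed history better.
candidates = [
    {"past": [0.0], "future": [9.0]},
    {"past": [1.0], "future": [2.0]},
]
chosen = self_verify(candidates, executed_past=[1.0])
```

Because the scoring reuses the same past tokens that PTP trains the policy to predict, this check comes at no extra training cost; it simply filters out sampled action sequences that contradict what the robot has already done.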