🤖 AI Summary
This work addresses the training instability of on-policy distillation (OPD) in long-horizon tasks, which the authors trace to three interrelated failure modes of the sampled-token variant: an imbalanced one-token reward signal, unreliable teacher guidance on student-generated prefixes, and distributional distortions caused by tokenizer or special-token mismatches. They analyze these issues from both the estimation and implementation perspectives and propose a concise yet effective remedy that stabilizes the training objective through teacher-guided top-K local support matching, implemented as truncated reverse-KL divergence with top-p rollout sampling and special-token masking. The approach achieves markedly better training stability and downstream performance while keeping gradient variance low, consistently outperforming conventional sampled-token OPD in both single-task mathematical reasoning and multi-task joint training that combines agentic and mathematical-reasoning tasks.
📝 Abstract
On-policy distillation (OPD) is appealing for large language model (LLM) post-training because it evaluates teacher feedback on student-generated rollouts rather than fixed teacher traces. In long-horizon settings, however, the common sampled-token variant is fragile: it reduces distribution matching to a one-token signal and becomes increasingly unreliable as rollouts drift away from prefixes the teacher commonly visits. We revisit OPD from the estimator and implementation sides. Theoretically, token-level OPD is biased relative to sequence-level reverse-KL, but it has a much tighter worst-case variance bound; our toy study shows the same tradeoff empirically, with stronger future-reward coupling producing higher gradient variance and less stable learning. Empirically, we identify three failure modes of sampled-token OPD: an imbalanced one-token signal, unreliable teacher guidance on student-generated prefixes, and distortions caused by tokenizer or special-token mismatch. We address these issues with teacher top-K local support matching, implemented as truncated reverse-KL with top-p rollout sampling and special-token masking. Across single-task math reasoning and multi-task agentic-plus-math training, this objective yields more stable optimization and better downstream performance than sampled-token OPD.
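The core objective described above, reverse KL restricted to the teacher's top-K local support, can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: the function name, the renormalize-on-support choice, and leaving special-token masking to the caller are all assumptions for the sake of a self-contained example.

```python
import torch
import torch.nn.functional as F

def topk_truncated_reverse_kl(student_logits, teacher_logits, k=20):
    """Hypothetical sketch of truncated reverse-KL on the teacher's top-K support.

    For each position, keep only the K tokens the teacher ranks highest,
    renormalize both distributions over that support, and compute
    KL(student || teacher) there. Positions corresponding to mismatched
    special tokens would be masked out by the caller before averaging.
    """
    # Teacher's top-K local support per position: (batch, seq, k)
    topk_vals, topk_idx = teacher_logits.topk(k, dim=-1)
    # Student logits gathered on the same support
    student_on_support = student_logits.gather(-1, topk_idx)
    # Renormalize both distributions over the truncated support
    log_p_student = F.log_softmax(student_on_support, dim=-1)
    log_p_teacher = F.log_softmax(topk_vals, dim=-1)
    # Reverse KL: E_{student}[log p_student - log p_teacher]
    p_student = log_p_student.exp()
    kl = (p_student * (log_p_student - log_p_teacher)).sum(dim=-1)
    return kl  # per-position loss, (batch, seq); mask and reduce downstream
```

Compared with a sampled-token objective, every position contributes a K-way matching signal rather than a single sampled token, which is what the abstract credits for the lower-variance, more stable optimization.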