🤖 AI Summary
Diffusion language models suffer from a mismatch between training and inference trajectories in parallel decoding, which limits both generation efficiency and quality. This work proposes TRIMS, a novel framework that, for the first time, integrates trajectory-ranking instructions into masked diffusion training. TRIMS leverages a lightweight autoregressive teacher model to provide trajectory supervision signals that guide the model toward a better token revelation order. By employing trajectory-aware masking strategies and supervised fine-tuning, the method significantly improves the trade-off between accuracy and parallelism on both LLaDA and Dream models. Remarkably, TRIMS achieves performance comparable to knowledge distillation approaches while avoiding their substantial training overhead.
📝 Abstract
Diffusion language models (DLMs) offer a promising path toward low-latency generation through parallel decoding, but their practical efficiency depends heavily on the decoding trajectory. In practice, this advantage often fails to fully materialize because standard training does not provide explicit supervision over token reveal order, creating a train-inference mismatch that leads to suboptimal decoding behavior. We propose Trajectory-Ranked Instruction Masked Supervision (TRIMS), a simple trajectory-guided supervised fine-tuning framework that injects trajectory supervision into standard Masked Diffusion Language Model (MDLM) training with minimal overhead. Instead of relying on costly DLM-based distillation, TRIMS uses lightweight signals from an autoregressive teacher to guide a trajectory-aware masking strategy, encouraging the model to learn more effective decoding orders. Experiments on LLaDA and Dream across math and coding benchmarks show that TRIMS significantly improves the accuracy-parallelism trade-off over both standard MDLM training and train-free acceleration baselines, while achieving competitive performance with prior distillation-based approaches at substantially lower training cost. Further analysis shows that TRIMS leads to better decoding trajectories, validating the effectiveness of trajectory-guided supervision for DLMs.
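The abstract does not give implementation details of the trajectory-aware masking strategy, but the core idea (masking tokens according to a teacher-derived reveal order rather than uniformly at random) can be illustrated with a minimal sketch. The function name, the rank representation, and the mask construction below are all hypothetical assumptions for illustration, not the paper's actual method:

```python
def trajectory_aware_mask(tokens, teacher_ranks, t, mask_id=-1):
    """Illustrative sketch (not the paper's exact method): mask a sequence
    at diffusion step t, revealing tokens in an order suggested by an
    autoregressive teacher.

    tokens        -- list of token ids
    teacher_ranks -- per-token reveal priority from the teacher
                     (lower rank = teacher suggests revealing earlier)
    t             -- fraction of tokens already revealed, in [0, 1]
    mask_id       -- id of the [MASK] token (placeholder value here)
    """
    n = len(tokens)
    k = int(t * n)  # number of tokens revealed at this step
    # Sort positions by teacher priority and reveal the top-k;
    # everything else stays masked, so the training trajectory
    # follows the teacher-suggested decoding order.
    order = sorted(range(n), key=lambda i: teacher_ranks[i])
    revealed = set(order[:k])
    return [tok if i in revealed else mask_id for i, tok in enumerate(tokens)]
```

For example, with `tokens = [10, 11, 12, 13]`, `teacher_ranks = [2, 0, 3, 1]`, and `t = 0.5`, the two highest-priority positions (indices 1 and 3) are revealed and the rest are masked, yielding `[-1, 11, -1, 13]`. Standard MDLM training would instead pick the two revealed positions uniformly at random, which is the lack of reveal-order supervision the abstract identifies.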