🤖 AI Summary
Masked diffusion language models (MDLMs) are expensive to sample from: generation requires many full-sequence denoising passes and, unlike autoregressive decoding, cannot benefit from KV caching. This work finds that the intermediate stages of the diffusion trajectory are the most sensitive to model substitution and proposes a simple, architecture-agnostic schedule that switches between a large and a small model: the smaller model runs during the early and late denoising steps, where generation is robust to replacement, and the larger model is invoked only during the sensitive intermediate phase. Supported by a step-importance analysis based on loss and KL divergence and by an exhaustive search over coarse step segments, the method reduces FLOPs by up to 17% for models trained on OpenWebText and LM1B, with only modest degradation in generative perplexity and with sample diversity preserved.
📝 Abstract
Recent advances in masked diffusion language models (MDLMs) have narrowed the quality gap to autoregressive LMs, but their sampling remains expensive because generation requires many full-sequence denoising passes with a large Transformer and, unlike autoregressive decoding, cannot benefit from KV caching. In this work, we exploit the flexibility of the diffusion framework and study model scheduling, where a smaller MDLM replaces the full model at a subset of denoising steps. Across models trained on OpenWebText and LM1B, we show that early and late denoising steps are substantially more robust to such replacement than middle steps, enabling up to a 17% reduction in FLOPs with only modest degradation in generative perplexity under both unconditional and prefix-conditional generation, while preserving sample diversity. We support these findings with a step-importance analysis based on loss and KL divergence between small and large models across timesteps, as well as an exhaustive search over coarse step segments; both consistently identify the middle of the diffusion trajectory as the most sensitive region across datasets. Our results suggest that simple, architecture-agnostic scheduling rules can significantly accelerate MDLM sampling while largely preserving generation quality.
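The scheduling rule described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function names (`build_schedule`, `relative_flops`, `run_sampler`), the segment boundaries, and the cost ratio of the small model are all hypothetical values chosen for illustration, and the denoisers are placeholders for real MDLM forward passes.

```python
from typing import Any, Callable, Dict, List


def build_schedule(num_steps: int, large_start: int, large_end: int) -> List[str]:
    """Assign a model to each denoising step.

    Steps in [large_start, large_end) use the large model (the sensitive
    middle of the trajectory); all earlier and later steps use the small one.
    The boundaries are assumptions, not values taken from the paper.
    """
    if not 0 <= large_start <= large_end <= num_steps:
        raise ValueError("need 0 <= large_start <= large_end <= num_steps")
    return [
        "large" if large_start <= step < large_end else "small"
        for step in range(num_steps)
    ]


def relative_flops(schedule: List[str], small_cost_ratio: float) -> float:
    """Sampling cost relative to running the large model at every step.

    small_cost_ratio is the (assumed) per-step cost of the small model
    divided by the per-step cost of the large model.
    """
    n_small = schedule.count("small")
    n_large = len(schedule) - n_small
    return (n_large + small_cost_ratio * n_small) / len(schedule)


def run_sampler(
    state: Any,
    schedule: List[str],
    denoisers: Dict[str, Callable[[Any, int], Any]],
) -> Any:
    """Apply the scheduled denoiser at each step, in order."""
    for step, name in enumerate(schedule):
        state = denoisers[name](state, step)
    return state


# Hypothetical setup: 10 steps, large model on steps 3..6, small model
# costing half as much per step as the large one.
schedule = build_schedule(num_steps=10, large_start=3, large_end=7)
cost = relative_flops(schedule, small_cost_ratio=0.5)  # 0.7 for these made-up numbers
```

Because the schedule is just a per-step lookup, it is independent of the model architecture, which is what makes this kind of rule easy to apply to any pair of MDLMs that share a vocabulary and noise schedule. The actual savings depend on the real cost ratio between the two models and on how many steps can be handed to the small one without hurting quality.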