KLASS: KL-Guided Fast Inference in Masked Diffusion Models

📅 2025-11-07
📈 Citations: 0
Influential: 0
🤖 AI Summary
Masked diffusion models suffer from slow inference and static sampling, limiting their practical deployment. To address this, we propose KLASS, the first method to employ token-level KL divergence as a confidence metric for dynamic, adaptive multi-token parallel decoding. Without requiring additional training, KLASS dynamically identifies high-confidence tokens in real time based on KL values and simultaneously updates both the masking pattern and predictions. Integrating KL-driven adaptive scheduling with synchronized decoding, KLASS is inherently compatible with diverse generation tasks—including text, image, and molecular synthesis. Extensive experiments demonstrate that KLASS achieves a 2.78× speedup over standard sequential sampling while preserving generation quality—outperforming greedy decoding and establishing new state-of-the-art efficiency among diffusion samplers.
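The core mechanism described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: it assumes a token is "stable" when the KL divergence between the model's predictive distributions at two consecutive refinement steps falls below a threshold, and such tokens become candidates for parallel unmasking. All names (`kl_per_token`, `select_stable_tokens`, `tau`) are hypothetical.

```python
import numpy as np

def kl_per_token(p_prev: np.ndarray, p_curr: np.ndarray,
                 eps: float = 1e-12) -> np.ndarray:
    """Token-level KL(p_prev || p_curr).

    p_prev, p_curr: (seq_len, vocab) predictive distributions from two
    consecutive denoising steps. Returns one KL value per position.
    """
    p = np.clip(p_prev, eps, 1.0)
    q = np.clip(p_curr, eps, 1.0)
    return np.sum(p * np.log(p / q), axis=-1)  # shape: (seq_len,)

def select_stable_tokens(p_prev: np.ndarray, p_curr: np.ndarray,
                         masked: np.ndarray, tau: float = 1e-3) -> list:
    """Return indices of still-masked positions whose predictions have
    stabilized (KL below tau), i.e. candidates to unmask in parallel."""
    kl = kl_per_token(p_prev, p_curr)
    return [int(i) for i in np.flatnonzero(masked) if kl[i] < tau]
```

In this sketch, a position whose distribution barely changed between steps (near-zero KL) is unmasked immediately, while a position whose prediction is still shifting stays masked for further refinement; decoding multiple low-KL positions per iteration is what yields the reported speedup.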

📝 Abstract
Masked diffusion models have demonstrated competitive results on various tasks, including language generation. However, due to their iterative refinement process, inference is often bottlenecked by slow, static sampling. To overcome this problem, we introduce KL-Adaptive Stability Sampling (KLASS), a fast yet effective sampling method that exploits token-level KL divergence to identify stable, high-confidence predictions. By unmasking multiple tokens in each iteration without any additional model training, our approach speeds up generation significantly while maintaining sample quality. On reasoning benchmarks, KLASS achieves up to 2.78× wall-clock speedups while improving performance over standard greedy decoding, attaining state-of-the-art results among diffusion-based samplers. We further validate KLASS across diverse domains, including text, image, and molecular generation, showing its effectiveness as a broadly applicable sampler across different models.
Problem

Research questions and friction points this paper is trying to address.

Accelerates slow inference in masked diffusion models
Maintains quality while unmasking multiple tokens per iteration
Applies KL divergence to identify stable predictions across domains
Innovation

Methods, ideas, or system contributions that make the work stand out.

KL-guided adaptive stability sampling for diffusion
Unmasking multiple tokens per iteration without retraining
Accelerates generation while maintaining sample quality