Learning Generation Orders for Masked Discrete Diffusion Models via Variational Inference

📅 2026-02-27
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the challenge of balancing parallel generation efficiency and sample quality in masked discrete diffusion models by introducing, for the first time, a variational inference perspective to model a learnable parallel generation order. The proposed framework employs a training-efficient parameterization of the approximate posterior, enabling significantly accelerated parallel sampling while preserving high generation quality. Experimental results on the GSM8K dataset demonstrate that the method achieves an accuracy of 33.1% with an average of only four generation steps, substantially outperforming existing approaches that report accuracies in the range of 23.7–29.0%. These findings confirm the dual advantage of the proposed approach in both computational efficiency and generation performance.

📝 Abstract
Masked discrete diffusion models (MDMs) are a promising new approach to generative modelling, offering the ability for parallel token generation and therefore greater efficiency than autoregressive counterparts. However, achieving an optimal balance between parallel generation and sample quality remains an open problem. Current approaches primarily address this issue through fixed, heuristic parallel sampling methods. Some recent learning-based approaches to this problem exist, but its formulation from the perspective of variational inference remains underexplored. In this work, we propose a variational inference framework for learning parallel generation orders for MDMs. As part of our method, we propose a parameterisation for the approximate posterior of generation orders which facilitates parallelism and efficient sampling during training. Using this method, we conduct preliminary experiments on the GSM8K dataset, where our method performs competitively against heuristic sampling strategies in the regime of highly parallel generation. For example, our method achieves 33.1% accuracy with an average of only 4 generation steps, compared to 23.7–29.0% accuracy achieved by standard competitor methods in the same number of steps. We believe further experiments and analysis of the method will yield valuable insights into the problem of parallel generation with MDMs.
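To make the core idea concrete, the following is a minimal toy sketch of parallel sampling in an MDM where a learned order model decides which masked positions to reveal at each step. The paper's actual parameterisation, training objective, and denoiser are not specified here; `denoiser_logits`, `order_scores`, and the `MASK` token id are all hypothetical stand-ins for illustration only.

```python
import math
import random

MASK = -1  # hypothetical mask token id

def denoiser_logits(seq, vocab_size):
    # Toy stand-in for a trained MDM denoiser: uniform logits for every
    # position. In practice this would be a transformer over the sequence.
    return [[0.0] * vocab_size for _ in seq]

def order_scores(seq):
    # Toy stand-in for the learned approximate posterior over generation
    # orders: a per-position "unmask now" score; already-revealed
    # positions are excluded with -inf.
    return [random.random() if tok == MASK else float("-inf") for tok in seq]

def parallel_sample(seq_len, vocab_size, steps, seed=0):
    """Reveal roughly seq_len/steps positions per step: the order model
    chooses which masked positions to unmask, and the denoiser fills the
    chosen positions in parallel."""
    random.seed(seed)
    seq = [MASK] * seq_len
    per_step = math.ceil(seq_len / steps)
    for _ in range(steps):
        masked = [i for i, t in enumerate(seq) if t == MASK]
        if not masked:
            break
        scores = order_scores(seq)
        # Pick the top-scoring masked positions to unmask this step.
        chosen = sorted(masked, key=lambda i: scores[i], reverse=True)[:per_step]
        logits = denoiser_logits(seq, vocab_size)
        for i in chosen:
            # Greedy decode with random tie-breaking over the toy uniform logits.
            seq[i] = max(range(vocab_size), key=lambda v: logits[i][v] + random.random())
    return seq
```

With `steps=4` and `seq_len=16`, four positions are revealed per step, mirroring the highly parallel regime (an average of 4 generation steps) reported in the abstract.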
Problem

Research questions and friction points this paper is trying to address.

masked discrete diffusion models
parallel generation
generation order
sample quality
variational inference
Innovation

Methods, ideas, or system contributions that make the work stand out.

masked discrete diffusion models
variational inference
parallel generation
generation order learning
approximate posterior