🤖 AI Summary
This work addresses the challenge of balancing parallel generation efficiency and sample quality in masked discrete diffusion models by introducing a variational inference perspective for modelling a learnable parallel generation order. The proposed framework employs a training-efficient parameterisation of the approximate posterior, enabling significantly accelerated parallel sampling while preserving generation quality. Experiments on the GSM8K dataset show that the method reaches 33.1% accuracy with an average of only four generation steps, substantially outperforming existing heuristic approaches, which achieve 23.7–29.0% accuracy under the same step budget. These results indicate a dual advantage of the proposed approach in both computational efficiency and generation quality.
📝 Abstract
Masked discrete diffusion models (MDMs) are a promising new approach to generative modelling, offering parallel token generation and therefore greater efficiency than autoregressive counterparts. However, achieving an optimal balance between parallel generation and sample quality remains an open problem. Current approaches primarily address this issue through fixed, heuristic parallel sampling methods. Some recent learning-based approaches to this problem exist, but its formulation from the perspective of variational inference remains underexplored. In this work, we propose a variational inference framework for learning parallel generation orders for MDMs. As part of our method, we propose a parameterisation of the approximate posterior over generation orders that facilitates parallelism and efficient sampling during training. Using this method, we conduct preliminary experiments on the GSM8K dataset, where our approach performs competitively against heuristic sampling strategies in the regime of highly parallel generation. For example, it achieves 33.1% accuracy with an average of only 4 generation steps, compared to the 23.7–29.0% accuracy achieved by standard competitor methods in the same number of steps. We believe further experiments and analysis will yield valuable insights into the problem of parallel generation with MDMs.
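To make the baseline concrete, the fixed heuristic samplers the abstract refers to typically unmask several positions per step ranked by model confidence (a MaskGIT-style rule), rather than learning the order as this paper proposes. Below is a minimal, purely illustrative sketch of that heuristic; `toy_model` is a hypothetical stand-in for a trained MDM denoiser, and this is not the paper's learned method.

```python
import numpy as np

MASK = -1  # sentinel token id for masked positions


def toy_model(tokens, vocab_size, rng):
    """Hypothetical stand-in for a trained MDM denoiser: returns
    per-position logits over the vocabulary (random here)."""
    return rng.standard_normal((len(tokens), vocab_size))


def parallel_unmask_step(tokens, vocab_size, k, rng):
    """One heuristic parallel step: predict every masked position,
    then commit only the k most confident predictions."""
    logits = toy_model(tokens, vocab_size, rng)
    # softmax over the vocabulary axis
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    confidence = probs.max(-1)
    predictions = probs.argmax(-1)

    masked = np.flatnonzero(tokens == MASK)
    if masked.size == 0:
        return tokens
    # rank masked positions by confidence, commit the top k
    chosen = masked[np.argsort(-confidence[masked])[:k]]
    out = tokens.copy()
    out[chosen] = predictions[chosen]
    return out


rng = np.random.default_rng(0)
seq = np.full(16, MASK)  # fully masked 16-token sequence
steps = 0
while (seq == MASK).any():
    seq = parallel_unmask_step(seq, vocab_size=50, k=4, rng=rng)
    steps += 1
```

With 16 tokens and 4 commits per step this finishes in 4 steps, matching the highly parallel regime (an average of 4 generation steps) discussed above; the paper's contribution is to replace this fixed confidence rule with a learned posterior over generation orders.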