Train for the Worst, Plan for the Best: Understanding Token Ordering in Masked Diffusions

📅 2025-02-10
🤖 AI Summary
This work addresses the fundamental trade-off in Masked Diffusion Models (MDMs) for discrete-domain generation: compared with autoregressive models, MDMs face higher training complexity, because they must learn an exponentially large number of infilling subproblems, in exchange for flexibility at inference time, where tokens can be decoded in essentially arbitrary order. We provide a theoretical and empirical analysis showing that this training burden includes computationally intractable subproblems that autoregressive models avoid. To sidestep these hard subproblems without retraining, we propose an uncertainty- and entropy-driven adaptive token decoding order strategy that applies directly to pretrained MDMs. On logic puzzles such as Sudoku, adaptive inference raises solving accuracy from under 7% to roughly 90%, outperforming autoregressive baselines with seven times as many parameters that were explicitly trained via teacher forcing to learn the right decoding order. We further introduce a dedicated evaluation framework for discrete logical reasoning. Our key contributions are threefold: (i) a formal theoretical characterization of MDM training hardness; (ii) a methodological decoupling of decoding order from model architecture and training; and (iii) empirical validation that adaptive ordering bypasses inherent training difficulties, yielding significant gains in inference performance.
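To make the idea of an uncertainty-driven decoding order concrete, here is a minimal sketch, not the authors' implementation. It assumes a hypothetical `predict` function standing in for a pretrained MDM (given a partially masked sequence, it returns one probability vector over the vocabulary per position), uses `None` as the mask token, and unmasks exactly one token per step by greedy argmax. The two selection rules shown (largest top-two probability margin, lowest entropy) are illustrative choices of uncertainty measure.

```python
import math
from typing import Callable, List, Optional, Sequence

MASK = None  # hypothetical sentinel marking a still-masked position


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (nats) of one position's predictive distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def margin(probs: Sequence[float]) -> float:
    """Gap between the two most likely tokens; a larger gap means more certain."""
    top2 = sorted(probs, reverse=True)[:2]
    return top2[0] - (top2[1] if len(top2) > 1 else 0.0)


def adaptive_decode(
    seq: List[Optional[int]],
    predict: Callable[[List[Optional[int]]], List[List[float]]],
    rule: str = "margin",
) -> List[Optional[int]]:
    """Unmask one token per step at the position the model is most sure about.

    `predict` is a hypothetical stand-in for a pretrained MDM.
    """
    seq = list(seq)
    while any(tok is MASK for tok in seq):
        probs = predict(seq)  # re-query the model on the current partial sequence
        masked = [i for i, tok in enumerate(seq) if tok is MASK]
        if rule == "entropy":
            # Lowest entropy = least uncertain masked position.
            pos = min(masked, key=lambda i: entropy(probs[i]))
        else:
            # Largest top-1 vs top-2 gap = least ambiguous masked position.
            pos = max(masked, key=lambda i: margin(probs[i]))
        # Commit the most likely token at the chosen position (greedy).
        row = probs[pos]
        seq[pos] = max(range(len(row)), key=row.__getitem__)
    return seq
```

Because the order is chosen at inference time from the model's own confidence, the same trained `predict` could be decoded left-to-right, in random order, or adaptively; only the position-selection step changes, which is why no retraining is needed.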

📝 Abstract
In recent years, masked diffusion models (MDMs) have emerged as a promising alternative approach for generative modeling over discrete domains. Compared to autoregressive models (ARMs), MDMs trade off complexity at training time with flexibility at inference time. At training time, they must learn to solve an exponentially large number of infilling problems, but at inference time, they can decode tokens in essentially arbitrary order. In this work, we closely examine these two competing effects. On the training front, we theoretically and empirically demonstrate that MDMs indeed train on computationally intractable subproblems compared to their autoregressive counterparts. On the inference front, we show that a suitable strategy for adaptively choosing the token decoding order significantly enhances the capabilities of MDMs, allowing them to sidestep hard subproblems. On logic puzzles like Sudoku, we show that adaptive inference can boost solving accuracy in pretrained MDMs from $<7\%$ to $\approx 90\%$, even outperforming ARMs with $7\times$ as many parameters and that were explicitly trained via teacher forcing to learn the right order of decoding.
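The training-side contrast can be illustrated with a sketch of one common form of the masked-diffusion objective. It assumes a linear masking schedule (mask each token independently with probability $t$, then weight the cross-entropy on masked positions by $1/t$); the paper's exact schedule and parameterization may differ, and `predict` is again a hypothetical stand-in for the model. Each call to `sample_infilling_problem` draws one of the $2^L$ possible masking patterns of a length-$L$ sequence, which is the exponentially large family of infilling problems the abstract refers to, whereas an ARM trained with teacher forcing sees only the $L$ left-to-right next-token problems.

```python
import math
import random
from typing import Callable, List, Optional, Sequence

MASK = None  # hypothetical sentinel for a masked token


def sample_infilling_problem(
    x: Sequence[int], t: float, rng: random.Random
) -> List[Optional[int]]:
    """Mask each token independently with probability t: one random infilling subproblem."""
    return [MASK if rng.random() < t else tok for tok in x]


def mdm_loss(
    x: Sequence[int],
    predict: Callable[[List[Optional[int]]], List[List[float]]],
    rng: random.Random,
) -> float:
    """One-sample Monte Carlo estimate of a masked-diffusion training loss.

    Assumes a linear schedule, under which the loss is the cross-entropy on the
    masked positions weighted by 1/t. `predict` is a hypothetical model that
    returns one probability vector per position.
    """
    t = rng.uniform(1e-3, 1.0)  # mask ratio; the 1e-3 floor is an illustrative guard
    xt = sample_infilling_problem(x, t, rng)
    probs = predict(xt)
    # Only masked positions contribute: the model must infill them from context.
    nll = sum(-math.log(probs[i][tok]) for i, tok in enumerate(x) if xt[i] is MASK)
    return nll / t
```

Averaging this loss over many draws of $t$ and masks trains the model on every infilling pattern at once, including the hard ones that adaptive decoding later avoids.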
Problem

Research questions and friction points this paper is trying to address.

Examines training complexity in masked diffusion models
Analyzes token decoding order impact on MDMs
Enhances MDM accuracy via adaptive inference strategies
Innovation

Methods, ideas, or system contributions that make the work stand out.

Masked diffusion models
Adaptive token decoding
Enhanced solving accuracy