🤖 AI Summary
This work addresses the absence of efficient and accurate inference-time unmasking strategies in masked diffusion language models, where existing approaches rely either on heuristic rules or on costly reinforcement learning. The authors propose Gt-Margin, a ground-truth-based per-position confidence metric, to derive an optimal unmasking sequence. Leveraging this metric, they train a supervised unmasking planner via learning-to-rank, enabling decoupled optimization without modifying the underlying generative model. Experimental results demonstrate that the proposed method significantly improves generation accuracy on logical reasoning tasks, validating both the effectiveness and the generalizability of the unmasking strategy.
📝 Abstract
Masked Diffusion Language Models (MDLMs) generate text by iteratively filling masked tokens, requiring two coupled decisions at each step: which positions to unmask (where-to-unmask) and which tokens to place there (what-to-unmask). While standard MDLM training directly optimizes token prediction (what-to-unmask), inference-time unmasking orders (where-to-unmask) are typically determined by heuristic confidence measures or learned through reinforcement learning with costly on-policy rollouts. To address this, we introduce Gt-Margin, a position-wise score derived from ground-truth tokens, defined as the probability margin between the correct token and its strongest alternative. Gt-Margin yields an oracle unmasking order that prioritizes easier positions first under each partially masked state. We demonstrate that following this oracle unmasking order significantly enhances final generation quality, particularly on logical reasoning benchmarks. Building on this insight, we train a supervised unmasking planner via learning-to-rank to imitate the oracle ordering from masked contexts. The resulting planner integrates into standard MDLM sampling to select where-to-unmask, improving reasoning accuracy without modifying the token prediction model.
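The Gt-Margin score and the oracle ordering it induces can be sketched in a few lines. This is a minimal illustration based only on the definition in the abstract (probability margin between the ground-truth token and its strongest alternative, with higher-margin positions unmasked first); the function names are illustrative, not the paper's implementation.

```python
import numpy as np

def gt_margin(probs, gt_ids, masked_positions):
    """Gt-Margin per masked position: p(ground-truth token) minus the
    probability of the strongest alternative token at that position.

    probs            -- (seq_len, vocab_size) predicted distributions
    gt_ids           -- ground-truth token id for each position
    masked_positions -- positions still masked in the current state
    """
    margins = {}
    for pos in masked_positions:
        p = probs[pos]
        p_gt = p[gt_ids[pos]]                       # prob. of the correct token
        p_alt = np.max(np.delete(p, gt_ids[pos]))   # strongest competing token
        margins[pos] = p_gt - p_alt
    return margins

def oracle_unmask_order(margins):
    # Higher margin = easier position; the oracle unmasks easier positions first.
    return sorted(margins, key=margins.get, reverse=True)
```

In an actual sampler this scoring would be recomputed after each unmasking step, since the model's distributions change as the partially masked state evolves.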
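One way the supervised learning-to-rank planner could be trained is with a pairwise (RankNet-style) objective that pushes the planner's score for a position above that of any position the oracle ranks below it. The sketch below is an assumption about the setup, not the paper's architecture or loss: `UnmaskPlanner` and `pairwise_rank_loss` are hypothetical names, and the planner is assumed to score each masked position from a contextual hidden state.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UnmaskPlanner(nn.Module):
    """Hypothetical planner head: maps each masked position's hidden state
    to a scalar score; higher score = unmask earlier."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.scorer = nn.Linear(hidden_dim, 1)

    def forward(self, hidden_states):       # (num_masked, hidden_dim)
        return self.scorer(hidden_states).squeeze(-1)

def pairwise_rank_loss(scores, oracle_margins):
    """RankNet-style pairwise loss: for every pair (i, j) where the oracle
    Gt-Margin of i exceeds that of j, penalize score_j exceeding score_i."""
    loss, n_pairs = 0.0, 0
    for i in range(len(scores)):
        for j in range(len(scores)):
            if oracle_margins[i] > oracle_margins[j]:
                loss = loss + F.softplus(scores[j] - scores[i])
                n_pairs += 1
    return loss / max(n_pairs, 1)
```

At inference time the planner replaces the heuristic confidence rule: the sampler unmasks the highest-scoring positions, while the frozen token predictor fills in their values.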