AI Summary
This work addresses the inability of masked diffusion models (MDMs) to correct their own errors during inference. We propose PRISM, a lightweight, model-agnostic, inference-time self-correction method that requires no architectural modification, retraining, reinforcement learning, or external verifier. It introduces, for the first time, a self-correction loss that provably learns token-level quality scores from generated trajectories. At inference, these scores are estimated in the same forward pass as the MDM's token predictions, and low-quality tokens are dynamically re-masked so they can be regenerated, enabling iterative refinement. On Sudoku solving, unconditional text generation with a 170M-parameter model, and code generation with LLaDA-8B, PRISM consistently improves generation quality and correction accuracy. These results indicate that PRISM is effective and robust across diverse generative tasks and model scales.
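The inference loop described above can be sketched in plain Python: score tokens, re-mask the weakest ones, then regenerate them. This is a minimal illustration under stated assumptions, not PRISM's implementation. The following names are all hypothetical:

- the fixed re-masking budget `k`;
- the `MASK` id;
- the `denoise` callable, a stand-in for the MDM's unmasking step;
- the `score` callable, a stand-in for its per-token quality scores.

```python
from typing import Callable, List, Sequence

# Hypothetical sketch, not PRISM's implementation.
MASK = -1  # hypothetical mask-token id, chosen only for this sketch


def remask_lowest_quality(
    tokens: Sequence[int],
    quality: Sequence[float],
    k: int,
    mask_id: int = MASK,
) -> List[int]:
    """Replace the k decoded tokens with the lowest quality scores by mask_id.

    Assumes quality[i] is higher when token i is more likely correct;
    positions that are already masked are never selected.
    """
    if len(tokens) != len(quality):
        raise ValueError("tokens and quality must have the same length")
    # Only positions that currently hold a decoded token can be re-masked.
    candidates = [i for i, t in enumerate(tokens) if t != mask_id]
    # Lowest quality first; ties are broken by position for determinism.
    worst = sorted(candidates, key=lambda i: (quality[i], i))[: max(k, 0)]
    out = list(tokens)
    for i in worst:
        out[i] = mask_id
    return out


def refine(
    tokens: Sequence[int],
    denoise: Callable[[List[int]], List[int]],
    score: Callable[[List[int]], List[float]],
    k: int,
    steps: int,
    mask_id: int = MASK,
) -> List[int]:
    """Alternate scoring, re-masking, and re-filling for a fixed number of steps.

    `denoise` fills masked positions and `score` returns per-token quality;
    both are hypothetical stand-ins for calls into an MDM.
    """
    current = list(tokens)
    for _ in range(steps):
        current = remask_lowest_quality(current, score(current), k, mask_id)
        current = denoise(current)
    return current


# Toy usage: the "denoiser" fills masks with 0 and the scorer trusts only 0s.
def fill_zeros(ts: List[int]) -> List[int]:
    return [0 if t == MASK else t for t in ts]


def trust_zeros(ts: List[int]) -> List[float]:
    return [1.0 if t == 0 else 0.0 for t in ts]


print(refine([3, 0, 4], fill_zeros, trust_zeros, k=1, steps=2))  # [0, 0, 0]
```

A fixed `k` per step is the simplest choice here. A schedule that shrinks `k` over steps, or a threshold on the score, would be equally plausible variations.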
Abstract
A natural desideratum for generative models is self-correction: detecting and revising low-quality tokens at inference. While Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces, their capacity for self-correction remains poorly understood. Prior attempts to incorporate self-correction into MDMs either require overhauling MDM architectures or training, or rely on imprecise proxies for token quality, limiting their applicability. Motivated by this, we introduce PRISM (Plug-in Remasking for Inference-time Self-correction of Masked Diffusions), a lightweight, model-agnostic approach that applies to any pretrained MDM. Theoretically, PRISM defines a self-correction loss that provably learns per-token quality scores, without reinforcement learning (RL) or a verifier. These quality scores are computed in the same forward pass as the MDM and used to detect low-quality tokens. Empirically, PRISM advances MDM inference across domains and scales: Sudoku, unconditional text generation (170M parameters), and code generation with LLaDA (8B).
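The abstract states that PRISM's loss provably learns per-token quality scores but does not give its form. As a purely hypothetical illustration, the sketch below shows one standard construction for such scores, built from two pieces:

- A binary cross-entropy raises a token's quality logit when the sampled token matches a reference token and lowers it otherwise.
- A sigmoid maps the logits to scores.

The function names, the reference sequence, and the BCE objective are assumptions, not the paper's loss. In a real model the logits would presumably come from an extra output computed in the same forward pass as the MDM's token predictions.

```python
import math
from typing import List, Sequence

# Hypothetical stand-in for a per-token quality objective, not PRISM's published loss.


def quality_scores(quality_logits: Sequence[float]) -> List[float]:
    """Map raw per-token logits to scores in (0, 1) with a numerically stable sigmoid."""
    scores = []
    for x in quality_logits:
        if x >= 0:
            scores.append(1.0 / (1.0 + math.exp(-x)))
        else:
            z = math.exp(x)  # avoids overflow of exp(-x) for large negative x
            scores.append(z / (1.0 + z))
    return scores


def quality_bce_loss(
    quality_logits: Sequence[float],
    sampled: Sequence[int],
    reference: Sequence[int],
) -> float:
    """Mean binary cross-entropy between quality logits and the target
    y_i = 1 if sampled[i] == reference[i] else 0."""
    if not (len(quality_logits) == len(sampled) == len(reference)):
        raise ValueError("all inputs must have the same length")
    if not quality_logits:
        raise ValueError("inputs must be non-empty")
    total = 0.0
    for x, s, r in zip(quality_logits, sampled, reference):
        y = 1.0 if s == r else 0.0
        # Stable form of BCE(sigmoid(x), y): max(x, 0) - x*y + log(1 + exp(-|x|))
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(quality_logits)


# Uninformative logits (0.0) cost log(2) per token, whether the token is right or wrong.
print(round(quality_bce_loss([0.0, 0.0], [1, 2], [1, 3]), 4))  # 0.6931
```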