🤖 AI Summary
Masked Diffusion Models (MDMs) waste computation during discrete sequence generation: because each token is either fully masked or fully unmasked, the sequence often stays unchanged between consecutive sampling steps, and the model reprocesses identical inputs. To address this, we propose Partial Masking (Prime), which adds intermediate states interpolated between the masked and unmasked states, so the model can predict from partially observed token information and denoise at a finer granularity than all-or-nothing masking allows. Methodologically, we derive a variational training objective and introduce a simple architectural design that accommodates intermediate-state inputs. On text, Prime reaches a perplexity of 15.36 on OpenWebText, ahead of previous MDMs (21.52), autoregressive models (17.54), and their hybrid variants (17.58), without relying on an autoregressive formulation. On images, it attains FID scores of 3.26 on CIFAR-10 and 6.98 on ImageNet-32, comparable to leading continuous generative models.
📝 Abstract
Masked diffusion models (MDMs) are powerful generative models for discrete data that generate samples by progressively unmasking tokens in a sequence. Each token can take one of two states: masked or unmasked. We observe that token sequences often remain unchanged between consecutive sampling steps; consequently, the model repeatedly processes identical inputs, leading to redundant computation. To address this inefficiency, we propose the Partial masking scheme (Prime), which augments MDMs by allowing tokens to take intermediate states interpolated between the masked and unmasked states. This design enables the model to make predictions based on partially observed token information and facilitates a fine-grained denoising process. We derive a variational training objective and introduce a simple architectural design to accommodate intermediate-state inputs. Our method demonstrates superior performance across a diverse set of generative modeling tasks. On text data, it achieves a perplexity of 15.36 on OpenWebText, outperforming previous MDMs (21.52), autoregressive models (17.54), and their hybrid variants (17.58), without relying on an autoregressive formulation. On image data, it attains competitive FID scores of 3.26 on CIFAR-10 and 6.98 on ImageNet-32, comparable to leading continuous generative models.
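The redundancy argument can be illustrated with a toy simulation. The sketch below is not the paper's implementation. It assumes, purely for illustration, that an intermediate state is obtained by writing each token id as a few base-`b` digits ("sub-tokens") and masking the digits independently, and that each unit is unmasked at an independent, uniformly random sampling step. The abstract does not specify either detail, and the names `to_subtokens`, `partially_mask`, and `count_idle_steps` are hypothetical.

```python
import random

MASK = None  # placeholder for a masked sub-token (hypothetical representation)


def to_subtokens(token_id, base, length):
    """Write token_id as `length` base-`base` digits, most significant first."""
    digits = []
    for _ in range(length):
        digits.append(token_id % base)
        token_id //= base
    return digits[::-1]


def partially_mask(subtokens, keep):
    """Hide the sub-tokens whose keep flag is False, giving an intermediate state."""
    return [d if k else MASK for d, k in zip(subtokens, keep)]


def count_idle_steps(num_units, num_steps, seed=0):
    """Count sampling steps in which no unit is unmasked.

    Assumption: each unit (a token, or a sub-token under partial masking)
    is unmasked at an independent, uniformly random step.
    """
    rng = random.Random(seed)
    reveal_steps = {rng.randrange(num_steps) for _ in range(num_units)}
    return num_steps - len(reveal_steps)


seq_len, num_steps, subtokens_per_token = 128, 1024, 4

# Binary masking: one state change per token over the whole trajectory.
idle_binary = count_idle_steps(seq_len, num_steps)

# Partial masking: one state change per sub-token, so more steps do work.
idle_partial = count_idle_steps(seq_len * subtokens_per_token, num_steps)

print(to_subtokens(50256, 16, 4))
print(partially_mask(to_subtokens(50256, 16, 4), [True, False, True, False]))
print(idle_binary, idle_partial)
```

With 128 tokens and 1024 steps, binary masking can change the sequence in at most 128 steps, so at least 896 steps see an input identical to the previous one. Splitting each token into four independently masked units raises the number of possible state changes to 512 over the same trajectory. This toy model only counts idle steps; it says nothing about sample quality.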