🤖 AI Summary
This work addresses two problems in training Masked Diffusion Models (MDMs): training on a very large set of random masking patterns is computationally expensive, and those random masks differ from the highly structured masks that arise during inference-time unmasking. To bridge this gap, the authors propose Progressive UnMAsking (PUMA), a simple modification of the forward masking process that aligns the training-time mask distribution with that of inference, so that optimization focuses on inference-aligned masks. By making the training and inference masking patterns consistent, PUMA reduces redundant computation and speeds up training, and it offers complementary gains on top of common recipes such as autoregressive initialization. Experiments at the 125M-parameter scale show that PUMA speeds up pretraining by approximately 2.5× without compromising generation quality, and in some cases improving it.
📝 Abstract
Masked Diffusion Models (MDMs) have emerged as a promising approach for generative modeling in discrete spaces. By generating sequences in any order and allowing for parallel decoding, they enable fast inference and strong performance on non-causal tasks. However, this flexibility comes with a training complexity trade-off: MDMs train on an exponentially large set of masking patterns, which is not only computationally expensive, but also creates a train-test mismatch between the random masks used in training and the highly structured masks induced by inference-time unmasking. In this work, we propose Progressive UnMAsking (PUMA), a simple modification of the forward masking process that aligns training-time and inference-time masking patterns, thereby focusing optimization on inference-aligned masks and speeding up training. Empirically, PUMA speeds up pretraining at the 125M scale by $\approx 2.5\times$ and offers complementary advantages on top of common recipes like autoregressive initialization. We open-source our codebase at https://github.com/JaeyeonKim01/PUMA.
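
To make the "exponentially large set of masking patterns" concrete, the following is a minimal sketch of the kind of random forward masking the abstract attributes to standard MDM training. It is not the PUMA procedure, whose details are not given here. The mask token string, the helper names `random_mask` and `num_mask_patterns`, and the choice to draw the masking ratio uniformly are all assumptions made for illustration.

```python
import random

MASK = "[MASK]"  # hypothetical mask token, chosen for illustration


def random_mask(tokens, mask_ratio, rng):
    """Mask each position independently with probability mask_ratio."""
    return [MASK if rng.random() < mask_ratio else tok for tok in tokens]


def num_mask_patterns(seq_len):
    """Count the distinct subsets of positions that could be masked."""
    return 2 ** seq_len


rng = random.Random(0)
tokens = ["the", "cat", "sat", "on", "the", "mat"]

# Assumed schedule: draw a masking ratio uniformly from [0, 1),
# then mask positions independently at that ratio.
mask_ratio = rng.random()
masked = random_mask(tokens, mask_ratio, rng)

print(masked)
print(num_mask_patterns(len(tokens)))  # 2**6 possible patterns for 6 tokens
```

Because every subset of positions can occur, a sequence of length L admits 2^L masking patterns, whereas an inference-time unmasking trajectory visits only a small, structured subset of them. That gap is the mismatch the abstract describes.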