Bringing Stability to Diffusion: Decomposing and Reducing Variance of Training Masked Diffusion Models

📅 2025-11-22
📈 Citations: 0
Influential: 0
🤖 AI Summary
Masked diffusion models (MDMs) suffer from high gradient variance during training, leading to unstable optimization and markedly worse downstream performance than autoregressive models (ARMs). This work presents the first theoretical decomposition of MDM training variance, identifying three sources: masking-pattern noise, masking-rate noise, and data noise. Building on this analysis, the authors propose six variance-reduction techniques, with two core contributions: P-POTS, a Pareto-optimal t sampler that draws harder t values more often while taking correspondingly smaller update steps, and MIRROR, which constructs negatively correlated sample pairs to suppress masking-pattern noise. Experiments show accuracy gains of 7-8% on complex reasoning tasks and run-to-run variability reduced to near-ARM levels; in most settings, even the worst run of the proposed method surpasses the best baseline run.

📝 Abstract
Masked diffusion models (MDMs) are a promising alternative to autoregressive models (ARMs), but they suffer from inherently much higher training variance. High variance leads to noisier gradient estimates and unstable optimization, so even equally strong pretrained MDMs and ARMs that are competitive at initialization often diverge after task-specific training, with MDMs falling far behind. There has been no theoretical explanation or systematic solution. We derive the first decomposition of MDM training variance into three sources: (A) masking pattern noise, (B) masking rate noise, and (C) data noise, while ARMs are only affected by (C). This explains the fundamental training gap. Building on this foundation, we design six variance-reduction methods, including two core methods: (1) P-POTS, a Pareto-optimal t sampler that minimizes training variance by sampling harder t values more often with appropriately smaller update steps, and (2) MIRROR, which uses negatively correlated samples to reduce (A). Experiments show that compared to standard MDM training, our methods improve accuracy by 7-8% on complex reasoning tasks, while simultaneously reducing run-to-run variability to near ARM levels, substantially narrowing the gap with strong ARM baselines; in most settings, even the best baseline runs remain below the worst run of our method.
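The paper releases no code here, but the P-POTS idea in the abstract, sampling harder t values more often while taking correspondingly smaller update steps, is an importance-sampling scheme, and its effect can be sketched on a toy model. Everything below is an assumption for illustration: the noise scale `sigma(t)`, the constant expected gradient, and the proposal `q(t) ∝ sqrt(second moment)` (a standard variance-minimizing choice) are not the paper's actual derivation.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 200_000

# Toy model of training at mask ratio t: the expected gradient is constant (1.0),
# but gradient noise grows with t (heavily masked steps are "harder").
def sigma(t):
    return 3.0 * t                       # assumed toy noise scale

def m2_sqrt(t):
    return np.sqrt(1.0 + sigma(t) ** 2)  # sqrt of the gradient second moment

# Standard MDM training: t ~ Uniform(0, 1), unit-weight updates.
t_u = rng.random(N)
grads_u = 1.0 + sigma(t_u) * rng.standard_normal(N)

# P-POTS-style training: oversample hard t with q(t) ∝ sqrt(second moment),
# drawn here by rejection sampling, then shrink each update by w = p(t)/q(t)
# so the gradient estimate stays unbiased.
samples = []
while len(samples) < N:
    cand = rng.random(N)
    accept = rng.random(N) * m2_sqrt(1.0) < m2_sqrt(cand)
    samples.extend(cand[accept])
t_q = np.array(samples[:N])

grid = np.linspace(0.0, 1.0, 100_001)
Z = m2_sqrt(grid).mean()                 # numeric normalizer of q(t)
w = Z / m2_sqrt(t_q)                     # smaller steps for oversampled hard t
grads_q = w * (1.0 + sigma(t_q) * rng.standard_normal(N))

print(grads_u.mean(), grads_q.mean())    # both ≈ 1: weighting keeps it unbiased
print(grads_u.var(), grads_q.var())      # variance drops under weighted sampling
```

The weighting is the key invariant: oversampling alone would bias the gradient toward hard t, while the p(t)/q(t) step shrinkage restores the original expectation and trades that bias away for lower variance.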
Problem

Research questions and friction points this paper is trying to address.

Reducing high training variance in masked diffusion models
Addressing unstable optimization and noisy gradient estimates
Narrowing performance gap with autoregressive models
Innovation

Methods, ideas, or system contributions that make the work stand out.

Decomposed training variance into three noise sources
Introduced P-POTS, a Pareto-optimal t sampler, for variance-minimizing time-step selection
Proposed MIRROR, which constructs negatively correlated sample pairs to cancel masking-pattern noise
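MIRROR's use of negatively correlated samples to reduce masking-pattern noise is, at its core, antithetic sampling: a mask and its complement cover the sequence exactly, so their noise partially cancels. The sketch below illustrates only that general principle on a toy per-mask loss; the mask construction, the quadratic loss, and all constants are assumptions, not the paper's method.

```python
import numpy as np

rng = np.random.default_rng(0)
L = 16                                   # toy sequence length
k = L // 2                               # mask exactly half the tokens
tok_loss = rng.uniform(0.5, 1.5, size=L) # assumed fixed per-token loss

def loss(idx):
    # Nonlinear per-mask loss, so the antithetic pair is not trivially constant.
    return tok_loss[idx].mean() ** 2

def independent_pair():
    # Baseline: average the loss over two independently drawn masking patterns.
    a = rng.permutation(L)[:k]
    b = rng.permutation(L)[:k]
    return 0.5 * (loss(a) + loss(b))

def antithetic_pair():
    # MIRROR-style: one random pattern plus its exact complement.
    p = rng.permutation(L)
    return 0.5 * (loss(p[:k]) + loss(p[k:]))

est_ind = np.array([independent_pair() for _ in range(20_000)])
est_ant = np.array([antithetic_pair() for _ in range(20_000)])
print(est_ind.var(), est_ant.var())      # complementary masks estimate with far less variance
```

Both estimators target the same expected loss, but the complementary pair forces the two halves' token sums to add up to a constant, so fluctuations in one half are offset by the other.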
Mengni Jia
University of Cambridge
Mengyu Zhou
Microsoft Research
Data analytics · Natural Language Processing · Network Science · Human Behaviors · Mobile & Ubiquitous Computing
Yihao Liu
Peking University
Xiaoxi Jiang
Alibaba Group
Guanjun Jiang
Alibaba Group