🤖 AI Summary
This work addresses the significant mismatch between training and inference in diffusion language models (DLMs), where training typically relies on static single-step masked prediction while inference involves multi-step progressive denoising. To bridge this gap, the authors propose a bilevel optimization framework that explicitly simulates the denoising process during training. The inner loop constructs parameterized memory via fast weights to capture local denoising trajectories of individual samples, while the outer loop updates the main model parameters based on this memory. This design shifts the memory burden from token representations to model parameters, enabling adaptive step adjustment and implicit retrieval of relevant information during inference. Experiments demonstrate that the approach accelerates convergence, reduces training loss, and achieves substantial improvements over baselines on long-context understanding and Needle-in-a-Haystack retrieval tasks.
📝 Abstract
Diffusion Language Models (DLMs) offer attractive advantages over Auto-Regressive (AR) models, such as full-attention parallel decoding and flexible generation. However, they suffer from a notable train-inference mismatch: DLMs are trained with a static, single-step masked prediction objective, but deployed through a multi-step progressive denoising trajectory. We propose MemDLM (Memory-Enhanced DLM), which narrows this gap by embedding a simulated denoising process into training via Bi-level Optimization. An inner loop updates a set of fast weights, forming a Parametric Memory that captures the local trajectory experience of each sample, while an outer loop updates the base model conditioned on this memory. By offloading memorization pressure from token representations to parameters, MemDLM yields faster convergence and lower training loss. Moreover, the inner loop can be re-enabled at inference time as an adaptation step, yielding additional gains on long-context understanding. We find that, when activated at inference time, this Parametric Memory acts as an emergent in-weight retrieval mechanism, helping MemDLM further reduce token-level attention bottlenecks on challenging Needle-in-a-Haystack retrieval tasks. Code: https://github.com/JarvisPei/MemDLM.
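The bilevel scheme the abstract describes can be sketched as a toy fast/slow-weight loop: an inner loop fits per-sample fast weights along a simulated denoising trajectory, and an outer loop updates the base weights conditioned on that parametric memory. Everything below is an illustrative assumption (a linear model, squared-error "denoising" loss, Gaussian noise as a stand-in for masking, made-up step sizes), not MemDLM's actual architecture or objective.

```python
import numpy as np

# Toy sketch of bilevel training with fast weights (parametric memory).
# All shapes, losses, and hyperparameters here are illustrative choices.

rng = np.random.default_rng(0)
D = 8                                     # toy feature dimension
W_slow = rng.normal(size=(D, D)) * 0.1    # base ("slow") parameters

def loss_and_grad(W, x, target):
    """Squared-error 'denoising' loss for a linear toy model."""
    err = x @ W - target
    return 0.5 * np.sum(err ** 2), np.outer(x, err)

def inner_loop(W_slow, x_clean, n_steps=4, lr=0.05, noise=0.5):
    """Fit per-sample fast weights along a simulated denoising trajectory."""
    W_fast = np.zeros_like(W_slow)        # parametric memory, starts empty
    for t in range(n_steps, 0, -1):
        # progressively less-noised view of the sample (stand-in for unmasking)
        x_noisy = x_clean + noise * (t / n_steps) * rng.normal(size=D)
        _, g = loss_and_grad(W_slow + W_fast, x_noisy, x_clean)
        W_fast -= lr * g                  # fast weights absorb trajectory info
    return W_fast

def outer_step(W_slow, x_clean, lr=0.05):
    """Update slow weights conditioned on the fast-weight memory."""
    W_fast = inner_loop(W_slow, x_clean)
    loss, g = loss_and_grad(W_slow + W_fast, x_clean, x_clean)
    return W_slow - lr * g, loss

losses = []
for _ in range(50):
    x = rng.normal(size=D)                # one "sample" per outer step
    W_slow, l = outer_step(W_slow, x)
    losses.append(l)

print(f"mean loss, first 10 steps: {np.mean(losses[:10]):.3f}, "
      f"last 10 steps: {np.mean(losses[-10:]):.3f}")
```

In this sketch the fast weights are discarded after each sample during training, mirroring how the memory captures only a local trajectory; re-running `inner_loop` at inference time corresponds to the adaptation step the abstract re-enables for long-context tasks.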