🤖 AI Summary
This work investigates the grokking phenomenon in masked diffusion language models—characterized by prolonged training plateaus at random performance followed by abrupt generalization—and attributes it to insufficient inductive bias. Using the $k$-parity problem as a probing task, the study uncovers an implicit decomposition of the training objective into a signal term that drives feature learning and a noise term that acts as implicit regularization. Leveraging this insight, the authors propose an optimized masking probability distribution that enables rapid and synchronous generalization without grokking, and they further introduce a scalable masking schedule derived from this principle. Experiments demonstrate substantial improvements: a 50M-parameter model achieves markedly lower perplexity, while an 8B-parameter model shows performance gains of up to 8.8% and 5.8% in pretraining and fine-tuning, respectively.
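The summary's core mechanism — sampling a masking level from a tunable distribution, masking tokens at that level, and scoring the model only on the masked positions — can be sketched generically. The paper's exact objective and schedule are not given here, so the names below (`sample_p`, `md_mask`) are illustrative placeholders, not the authors' API; the model call and cross-entropy computation are elided.

```python
import random

def md_mask(tokens, mask_id, sample_p):
    """One masked-diffusion corruption step (illustrative sketch).

    Samples a masking level p from `sample_p` -- standing in for the
    tunable masking probability distribution the summary describes --
    then masks each token independently with probability p. Returns the
    corrupted sequence and the positions the loss would be computed on.
    """
    p = sample_p()  # masking level drawn from the chosen distribution
    masked = [mask_id if random.random() < p else t for t in tokens]
    loss_positions = [i for i, t in enumerate(masked) if t == mask_id]
    return masked, loss_positions
```

In an actual training loop, the model would predict the original token at each position in `loss_positions`, and the choice of distribution behind `sample_p` is what the work optimizes.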
📝 Abstract
Masked Diffusion Language Models have recently emerged as a powerful generative paradigm, yet their generalization properties remain understudied compared to those of their auto-regressive counterparts. In this work, we investigate these properties in the setting of the $k$-parity problem (computing the XOR sum of $k$ relevant bits), where neural networks typically exhibit grokking -- a prolonged plateau of chance-level performance followed by sudden generalization. We theoretically decompose the Masked Diffusion (MD) objective into a Signal regime, which drives feature learning, and a Noise regime, which serves as an implicit regularizer. By training nanoGPT with the MD objective on the $k$-parity problem, we demonstrate that the MD objective fundamentally alters the learning landscape, enabling rapid and simultaneous generalization without grokking. Furthermore, we leverage our theoretical insights to optimize the masking probability distribution in the MD objective. Our method significantly improves perplexity for 50M-parameter models and achieves superior results in both pre-training from scratch and supervised fine-tuning. Specifically, we observe performance gains peaking at $8.8\%$ and $5.8\%$, respectively, on 8B-parameter models, confirming the scalability and effectiveness of our framework in large-scale masked diffusion language model regimes.
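For concreteness, the $k$-parity task from the abstract can be sketched as a data-generation routine: the label of a random $n$-bit string is the XOR of the bits at $k$ fixed "relevant" indices. The function name and signature below are illustrative, not from the paper.

```python
import random

def parity_sample(n, relevant):
    """Draw a uniform random n-bit string; the label is the XOR (parity)
    of the bits at the `relevant` indices (len(relevant) == k)."""
    bits = [random.randint(0, 1) for _ in range(n)]
    label = 0
    for i in relevant:
        label ^= bits[i]
    return bits, label
```

Since the $n - k$ irrelevant bits are uncorrelated with the label, a network must discover the hidden relevant subset, which is what makes $k$-parity a standard probe for grokking.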