AI Summary
The scalability and effectiveness of Masked Diffusion Models (MDMs) for language tasks remain poorly understood.
Method: We systematically investigate the scaling behavior of MDMs for text generation and understanding, establishing the first MDM scaling law and training a model family with up to 1.1B parameters. We propose an unsupervised classifier-free guidance scheme that exploits large-scale unpaired data to boost conditional inference. Because MDMs are bidirectional by construction, they sidestep the "reverse curse" of autoregressive models and adapt to temporal shifts in input sequences.
Results: Our 1.1B MDM outperforms a 1.1B TinyLlama trained on identical data on 4 of 8 zero-shot understanding benchmarks; matches the GSM8K mathematical reasoning performance of the 7B Llama-2; and, given 16× more pre-training time, matches the generation quality of autoregressive models (ARMs) while sampling 1.4× faster than ARMs accelerated with KV-Cache. These results demonstrate that MDMs scale effectively across diverse language tasks while offering computational and architectural advantages over conventional autoregressive approaches.
Abstract
Masked diffusion models (MDMs) have shown promise in language modeling, yet their scalability and effectiveness on core language tasks, such as text generation and language understanding, remain underexplored. This paper establishes the first scaling law for MDMs, demonstrating a scaling rate comparable to autoregressive models (ARMs) and a relatively small compute gap. Motivated by their scalability, we train a family of MDMs with up to 1.1 billion (B) parameters to systematically evaluate their performance against ARMs of comparable or larger sizes. Fully leveraging the probabilistic formulation of MDMs, we propose a simple yet effective unsupervised classifier-free guidance that exploits large-scale unpaired data, boosting performance for conditional inference. In language understanding, the 1.1B MDM outperforms the 1.1B TinyLlama model trained on the same data on four of eight zero-shot benchmarks. Notably, it achieves math reasoning ability competitive with the 7B Llama-2 model on the GSM8K dataset. In text generation, MDMs given 16 times more pre-training time offer a flexible trade-off against ARMs equipped with the accelerated sampling technique KV-Cache: MDMs match ARMs in performance while being 1.4 times faster during sampling. Moreover, MDMs handle tasks that are challenging for ARMs, such as bidirectional reasoning and adapting to temporal shifts in data. Notably, a 1.1B MDM breaks the reverse curse encountered by much larger ARMs trained with significantly more data and computation, such as 13B Llama-2 and 175B GPT-3. Our code is available at https://github.com/ML-GSAI/SMDM.
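The classifier-free guidance mentioned above follows the standard recipe of sampling from a distribution proportional to p(x|c)^(1+w) / p(x)^w, which in logit space is a weighted combination of a conditional and an unconditional forward pass. The sketch below shows only that generic combination step; the function name, the weight `w`, and the choice of obtaining the unconditional pass by masking out the prompt are illustrative assumptions, not the paper's exact API.

```python
import numpy as np

def cfg_logits(cond_logits: np.ndarray, uncond_logits: np.ndarray, w: float) -> np.ndarray:
    """Classifier-free guidance combination (illustrative sketch).

    Returns unnormalized log-probabilities of a distribution proportional
    to p(x | c)^(1 + w) / p(x)^w. With w = 0 this reduces to plain
    conditional sampling; larger w sharpens the influence of the
    condition c on the denoised tokens.
    """
    return (1.0 + w) * cond_logits - w * uncond_logits
```

In a masked diffusion sampler, both passes can share one set of weights: the conditional pass sees the prompt, while the unconditional pass would replace the prompt with mask tokens, which is one natural way to let large-scale unpaired data train both terms at once.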