Scaling up Masked Diffusion Models on Text

πŸ“… 2024-10-24
πŸ›οΈ arXiv.org
πŸ“ˆ Citations: 6
✨ Influential: 0
πŸ“„ PDF
πŸ€– AI Summary
The scalability and effectiveness of Masked Diffusion Models (MDMs) for language tasks remain poorly understood. Method: We systematically investigate MDMs for text generation and understanding, establishing the first MDM scaling law and training a model family with up to 1.1B parameters. We propose an unsupervised classifier-free guidance scheme that leverages large-scale unpaired data to boost conditional inference. Results: Our 1.1B MDM outperforms a 1.1B TinyLlama trained on identical data on 4 out of 8 zero-shot language-understanding benchmarks; matches the GSM8K mathematical reasoning performance of 7B Llama-2; and samples 1.4× faster than autoregressive models (ARMs) accelerated with KV-Cache while matching their generation quality. MDMs also handle bidirectional reasoning and temporal shifts in data, breaking the reverse curse that afflicts much larger ARMs. These results demonstrate that MDMs scale effectively across diverse language tasks while offering computational and architectural advantages over conventional autoregressive approaches.

πŸ“ Abstract
Masked diffusion models (MDMs) have shown promise in language modeling, yet their scalability and effectiveness in core language tasks, such as text generation and language understanding, remain underexplored. This paper establishes the first scaling law for MDMs, demonstrating a scaling rate comparable to autoregressive models (ARMs) and a relatively small compute gap. Motivated by their scalability, we train a family of MDMs with up to 1.1 billion (B) parameters to systematically evaluate their performance against ARMs of comparable or larger sizes. Fully leveraging the probabilistic formulation of MDMs, we propose a simple yet effective unsupervised classifier-free guidance that effectively exploits large-scale unpaired data, boosting performance for conditional inference. In language understanding, the 1.1B MDM outperforms the 1.1B TinyLlama model trained on the same data across four of eight zero-shot benchmarks. Notably, it achieves competitive math reasoning ability with the 7B Llama-2 model on the GSM8K dataset. In text generation, MDMs with 16 times more pre-training time offer a flexible trade-off against ARMs with the accelerated sampling technique KV-Cache: MDMs match ARMs in performance while being 1.4 times faster during sampling. Moreover, MDMs address challenging tasks for ARMs by effectively handling bidirectional reasoning and adapting to temporal shifts in data. Notably, a 1.1B MDM breaks the reverse curse encountered by much larger ARMs with significantly more data and computation, such as 13B Llama-2 and 175B GPT-3. Our code is available at https://github.com/ML-GSAI/SMDM.
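The classifier-free guidance mentioned in the abstract combines a conditional and an unconditional prediction, extrapolating toward the condition with a strength parameter. A minimal numeric sketch of that combination (function name, values, and the guidance form are illustrative assumptions, not the paper's code; the paper's unsupervised variant additionally trains the unconditional model on unpaired data):

```python
import numpy as np

def cfg_logits(cond_logits, uncond_logits, w):
    # Classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one with strength w.
    # w = 0 recovers the plain conditional model.
    return (1 + w) * cond_logits - w * uncond_logits

cond = np.array([1.0, 2.0, 0.5])    # illustrative conditional logits
uncond = np.array([0.5, 1.5, 1.0])  # illustrative unconditional logits
print(cfg_logits(cond, uncond, 2.0))  # sharpens toward the condition
```

Larger `w` pushes samples further toward the conditioning signal, typically trading diversity for fidelity.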
Problem

Research questions and friction points this paper is trying to address.

Scaling masked diffusion models for language tasks
Evaluating MDMs against autoregressive models in performance
Handling bidirectional reasoning and temporal data shifts
Innovation

Methods, ideas, or system contributions that make the work stand out.

Scaling masked diffusion models to 1.1B parameters
Unsupervised classifier-free guidance using unpaired data
Flexible quality–speed trade-off: 1.4× faster sampling than KV-Cache-accelerated ARMs at matched generation quality
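Unlike ARMs, which decode strictly left to right, MDMs generate by iteratively unmasking positions of an initially fully masked sequence. A toy sketch of this unmasking loop, with a random stand-in for the trained denoiser (all names, the schedule, and the greedy unmasking rule are illustrative assumptions, not the SMDM implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, MASK = 5, 5  # token ids 0..4; id 5 is the [MASK] token

def toy_denoiser(x):
    # Stand-in for the trained MDM: random per-position logits.
    return rng.standard_normal((len(x), VOCAB))

def mdm_sample(length, steps):
    # Start fully masked; each step, predict all positions and
    # reveal just enough of them to finish in the remaining steps.
    x = np.full(length, MASK)
    for s in range(steps, 0, -1):
        preds = toy_denoiser(x).argmax(-1)
        masked = np.flatnonzero(x == MASK)
        k = int(np.ceil(len(masked) / s))
        chosen = rng.choice(masked, size=k, replace=False)
        x[chosen] = preds[chosen]
    return x

print(mdm_sample(8, 4))  # all 8 positions unmasked after 4 steps
```

Because every step predicts all masked positions in parallel, fewer denoising steps than sequence positions suffice, which is one source of the sampling-speed advantage over token-by-token ARM decoding.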
πŸ”Ž Similar Papers
No similar papers found.
✍️ Authors
Shen Nie — Renmin University of China (generative models)
Fengqi Zhu — Gaoling School of Artificial Intelligence, Renmin University of China; Beijing Key Laboratory of Big Data Management and Analysis Methods
Chao Du — Sea AI Lab, Singapore
Tianyu Pang — Sea AI Lab, Singapore
Qian Liu — Sea AI Lab, Singapore
Guangtao Zeng — SUTD (natural language processing)
Min Lin — Principal Research Scientist, Sea AI Lab (artificial intelligence)
Chongxuan Li — Associate Professor, Renmin University of China (machine learning, generative models, deep learning)