AI Summary
This work addresses the high computational and memory costs of discrete diffusion language models, which stem from their full-vocabulary prediction layers that hinder efficient training under resource constraints. The authors propose the first integration of a lexical tree into this framework, enabling tree-factorized prediction during the diffusion process by representing token ancestors through hidden states. This approach exponentially reduces the classification dimensionality, rendering the prediction head's parameter count negligible and freeing up resources to enhance the attention modules. Under an identical parameter budget, the method achieves a 50% reduction in peak GPU memory usage while matching the state-of-the-art perplexity performance of existing discrete diffusion language models.
Abstract
Discrete diffusion language models have emerged as a competitive alternative to auto-regressive language models, but training them efficiently under limited parameter and memory budgets remains challenging. Modern architectures are predominantly based on a full-vocabulary token prediction layer, which accounts for a substantial fraction of model parameters (e.g., more than 20% in small-scale DiT-style designs) and often dominates peak GPU memory usage. This leads to inefficient use of both parameters and memory under constrained training resources. To address this issue, we revisit the necessity of explicit full-vocabulary prediction and instead exploit the inherent structure among tokens to build a tree-structured diffusion language model. Specifically, we model the diffusion process with intermediate latent states corresponding to a token's ancestor nodes in a pre-constructed vocabulary tree. This tree-structured factorization exponentially reduces the classification dimensionality, makes the prediction head negligible in size, and enables reallocation of parameters to deepen the attention blocks. Empirically, under the same parameter budget, our method reduces peak GPU memory usage by half while matching the perplexity of state-of-the-art discrete diffusion language models.
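To make the claimed exponential reduction concrete, here is a minimal back-of-envelope sketch in Python. It assumes a complete binary vocabulary tree with one shared 2-way classifier per tree level, which is only one plausible instantiation of tree-factorized prediction; the paper's actual tree construction and parameterization may differ, and the function names (`tree_head_params`, `token_to_path`) are illustrative, not from the paper.

```python
import math

def tree_head_params(vocab_size, hidden_dim, branching=2):
    """Compare parameter counts of a full-vocabulary prediction head
    (a V x d projection) against a tree-factorized head that shares one
    `branching`-way linear classifier per tree level (hypothetical design)."""
    depth = math.ceil(math.log(vocab_size, branching))  # path length to a leaf
    full_head = vocab_size * hidden_dim                 # V * d parameters
    tree_head = depth * branching * hidden_dim          # b * d per level
    return full_head, tree_head, depth

def token_to_path(token_id, depth, branching=2):
    """Map a token id to its ancestor path in a complete b-ary tree:
    the root-to-leaf sequence of branch choices, one per tree level."""
    path = []
    for _ in range(depth):
        path.append(token_id % branching)
        token_id //= branching
    return path[::-1]  # most-significant branch choice first

# GPT-2-like vocabulary (50,257 tokens) with a 768-dim hidden state:
full, tree, depth = tree_head_params(vocab_size=50_257, hidden_dim=768)
# depth 16: the head shrinks from ~38.6M parameters to ~24.6K,
# i.e. per-step classification is over 2 classes instead of 50,257.
```

Under these assumptions the head's cost scales with `log(V)` rather than `V`, which is why the freed parameters can be reallocated to deeper attention blocks, as the abstract describes.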