🤖 AI Summary
Parallel text generation models typically underperform their autoregressive counterparts because modeling the complex joint distribution over token sequences is difficult. This work proposes Gumbel Distillation, a novel approach that introduces the Gumbel-Max trick into knowledge distillation for the first time, establishing a deterministic mapping from a latent Gumbel noise space to the outputs of a high-performing autoregressive teacher model. This lets parallel decoders learn the sequence-level joint distribution without architecture-specific assumptions, so the method applies to diverse non-autoregressive models such as MDLM and BD3-LM. Experiments on LM1B and OpenWebText demonstrate substantial improvements, including a 30.0% increase in MAUVE score and a 10.5% reduction in generative perplexity, significantly narrowing the performance gap with autoregressive models.
📝 Abstract
The slow, sequential nature of autoregressive (AR) language models has driven the adoption of parallel decoding methods. However, these non-AR models often sacrifice generation quality because they struggle to model the complex joint distribution of token sequences. To narrow this performance gap, we introduce Gumbel Distillation, a novel distillation technique that enables parallel decoders to learn this distribution effectively. Our method leverages the Gumbel-Max trick to create a deterministic mapping from a latent Gumbel noise space to the output tokens of a high-performing AR teacher. As a model-agnostic technique, Gumbel Distillation integrates seamlessly with diverse parallel decoding architectures, including MDLM and BD3-LM. Experiments on LM1B and OpenWebText show that Gumbel Distillation substantially improves the generation quality of parallel language models, achieving a 30.0% improvement in MAUVE score and a 10.5% reduction in generative perplexity over MDLM trained on the OpenWebText dataset. Code is available at https://github.com/hxixixh/gumbel-distill.
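To make the core building block concrete, here is a minimal sketch of the Gumbel-Max trick the method relies on: adding i.i.d. Gumbel(0, 1) noise to a categorical distribution's log-probabilities and taking the argmax yields an exact sample from that distribution. The distribution `probs` and the function name below are illustrative toys, not the paper's actual teacher model or implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for a teacher's next-token distribution (hypothetical values).
probs = np.array([0.1, 0.6, 0.3])

def gumbel_max_sample(log_p, rng):
    """Draw one categorical sample via argmax(log p + Gumbel noise)."""
    g = rng.gumbel(size=log_p.shape)  # latent Gumbel noise
    return int(np.argmax(log_p + g))

# Empirically, the argmax frequencies match `probs`, confirming that the
# noise-to-token mapping is a faithful sampler of the distribution.
samples = [gumbel_max_sample(np.log(probs), rng) for _ in range(100_000)]
freqs = np.bincount(samples, minlength=3) / len(samples)
```

Because the mapping from a fixed noise draw to a token is deterministic, a student model can be trained to reproduce the teacher's output for each noise realization, which is the intuition behind distilling in the Gumbel noise space.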