Learned Relay Representations for Forward-Thinking Discrete Diffusion Models

📅 2026-05-21
📈 Citations: 0
Influential: 0

🤖 AI Summary
Masked diffusion language models waste computation during iterative denoising: the intermediate representations computed at each step are discarded and must be recomputed at the next. This work proposes Learned Relay Representations (Relay), a differentiable, explicit mechanism for relaying information across decoding steps in discrete diffusion language models. Relay passes a per-token latent state from one decoding step to the next and is trained with truncated backpropagation through time (BPTT), which teaches the model to denoise with future steps in mind. The approach is compatible with both block diffusion and key-value (KV) caching. Applied to Fast-dLLM v2, it outperforms standard supervised fine-tuning on coding tasks while reducing inference latency by up to 32%, advancing the performance-latency Pareto frontier instead of trading one for the other.
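The contrast the summary draws, between discarding per-step internal state and relaying it forward, can be made concrete with a toy decoding loop. This is a minimal sketch under loud assumptions: `toy_denoiser`, its random weights, the zero-initialized relay, and the confidence-based unmasking rule are all hypothetical stand-ins, not the Relay architecture or Fast-dLLM v2's decoder. The only point it illustrates is that each forward pass returns a per-token state that the next pass consumes instead of starting from scratch.

```python
import numpy as np

MASK = -1  # hypothetical id for the [MASK] token


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def toy_denoiser(tokens, relay, W_in, W_relay, W_out):
    """Hypothetical one-layer stand-in for a masked-diffusion forward pass.

    tokens: (L,) ints, MASK at masked positions
    relay:  (L, d) per-token state carried over from the previous step
    Returns per-token logits and the next relay state.
    """
    d = W_in.shape[1]
    emb = np.zeros((tokens.shape[0], d))
    known = tokens != MASK
    emb[known] = W_in[tokens[known]]         # masked positions keep a zero embedding
    hidden = np.tanh(emb + relay @ W_relay)  # mix current input with the relayed state
    logits = hidden @ W_out                  # (L, vocab)
    return logits, hidden                    # hidden is relayed forward, not discarded


def decode(length, steps, params):
    """Confidence-based unmasking loop that threads a relay state across steps."""
    d = params[0].shape[1]
    tokens = np.full(length, MASK)
    relay = np.zeros((length, d))            # empty relay before the first step
    per_step = -(-length // steps)           # ceil(length / steps) tokens per step
    for _ in range(steps):
        masked = np.flatnonzero(tokens == MASK)
        if masked.size == 0:
            break
        logits, relay = toy_denoiser(tokens, relay, *params)
        conf = softmax(logits[masked]).max(axis=1)
        pick = masked[np.argsort(-conf)[:per_step]]  # most confident masked positions
        tokens[pick] = logits[pick].argmax(axis=1)
    return tokens
```

In a model without a relay, `relay` would be reset to zeros (or simply absent) on every iteration; here the hidden state computed at step k is an input at step k + 1, which is the kind of cross-step channel Relay learns.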
📝 Abstract
When Masked Diffusion Models (MDMs) generate sequences through iterative refinement, the rich internal computation over masked positions is discarded, forcing every subsequent refinement step to recompute the valuable internal information stored as model representations. To avoid a hard reset between denoising rounds, we propose Learned Relay Representations (Relay), a method that allows MDMs to be forward-thinking when denoising by explicitly learning how to propagate latent information for the benefit of future denoising steps. Relay introduces a differentiable per-token channel that passes information between forward passes and is trained via truncated backpropagation through time (BPTT). We show that this framework can be scaled to state-of-the-art Diffusion Language Models (DLMs), and is seamlessly compatible with techniques like block diffusion and KV caching. We first provide a thorough justification of the design choices in Relay on a challenging Sudoku-based planning task. We then scale Relay to Fast-dLLM v2, a state-of-the-art DLM, outperforming standard supervised finetuning on coding tasks while reducing inference latency by up to 32%. Our empirical results demonstrate that state-of-the-art DLMs can be explicitly trained to relay latent information forward across decoding steps, advancing the performance-latency Pareto frontier. We provide code for all our experiments.
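The abstract states that the relay channel is trained with truncated BPTT across forward passes. As a generic illustration of that training technique (not the paper's model, loss, or window size), the sketch below replaces the relay with a hypothetical scalar recurrence r_{k+1} = a·r_k + x_k, one update per denoising step, with a squared-error loss at every step. Gradients are backpropagated within windows of `window` steps, and the state entering each window is treated as a constant, which is what "truncated" means here.

```python
import numpy as np


def rollout(a, r0, xs):
    """Hypothetical relay recurrence r_{k+1} = a * r_k + x_k; returns [r_0, ..., r_K]."""
    rs = [r0]
    for x in xs:
        rs.append(a * rs[-1] + x)
    return np.array(rs)


def loss(a, r0, xs, ys):
    """Squared error between each new relay value and its target."""
    rs = rollout(a, r0, xs)
    return 0.5 * np.sum((rs[1:] - ys) ** 2)


def truncated_bptt_grad(a, r0, xs, ys, window):
    """dLoss/da with gradients flowing back through at most `window` steps.

    The relay value entering each window is used as a constant
    (the 'detach' step of truncated BPTT).
    """
    grad, r = 0.0, r0
    for start in range(0, len(xs), window):
        x_win, y_win = xs[start:start + window], ys[start:start + window]
        rs = rollout(a, r, x_win)            # forward pass over this window
        carry = 0.0                          # a * dL/dr_{k+2}, flowing backward
        for k in reversed(range(len(x_win))):
            delta = (rs[k + 1] - y_win[k]) + carry  # total dL/dr_{k+1} within the window
            grad += delta * rs[k]                    # local partial dr_{k+1}/da = r_k
            carry = a * delta                        # since dr_{k+1}/dr_k = a
        r = rs[-1]                           # pass the value on, cut the gradient
    return grad
```

With `window` at least the number of steps, this is the full BPTT gradient of `loss` and agrees with a finite-difference estimate; smaller windows bound the length and cost of the backward pass at the price of ignoring credit assignment beyond the window.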
Problem

Research questions and friction points this paper is trying to address.

Masked Diffusion Models
iterative refinement
latent information propagation
denoising steps
inference latency
Innovation

Methods, ideas, or system contributions that make the work stand out.

Learned Relay Representations
Masked Diffusion Models
Forward-Thinking Denoising
Truncated BPTT
Diffusion Language Models
🔎 Similar Papers
2024-04-19 · Neural Information Processing Systems · Citations: 14
2024-07-16 · arXiv.org · Citations: 2