TReASURe: Tree Reward-Aligned Search in Masked Diffusion Language Models

📅 2025-09-27
📈 Citations: 0
Influential: 0
🤖 AI Summary
Masked diffusion language models (MDLMs) pose two key challenges for tree search: parallel unmasking produces highly correlated branches, and stochastic sampling-based completion yields high-variance reward estimates. To address these, the paper proposes TReASURe, a tree-search framework for MDLMs. Its core contributions are (1) UnmaskBranch, a branching strategy that generates diverse candidate branches, varying both token content and unmasking order, with a single model call per parent node; and (2) ResubstituteScore, a pruning rule that uses deterministic resubstitution for low-variance reward estimation, with theoretical error bounds tied to predictive uncertainty. TReASURe also integrates first-hitting unmasking, parallel unmasking optimization, and proxy completions to keep the number of function evaluations (NFEs) low. Experiments show state-of-the-art results on perplexity, linguistic acceptability, and controllable generation (e.g., sentiment and toxicity control), substantially outperforming prior methods, especially under tight NFE budgets.
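The branch-and-prune loop described above can be sketched in a few dozen lines. Everything here is a toy illustration under stated assumptions, not the paper's implementation: `toy_model`, the vocabulary, the reward, and the tree parameters are all hypothetical stand-ins. Each parent spawns children that differ in both which position is revealed and which token is revealed there (the UnmaskBranch idea), and partially masked children are scored by deterministically filling remaining masks with argmax predictions (the ResubstituteScore idea).

```python
import random

MASK = "<mask>"
VOCAB = ["good", "great", "fine", "bad"]

def toy_model(seq):
    """Hypothetical stand-in for an MDLM denoiser: returns, for each
    masked position, a probability distribution over a toy vocabulary."""
    rng = random.Random(" ".join(seq))  # deterministic per sequence
    probs = {}
    for i, tok in enumerate(seq):
        if tok == MASK:
            w = [rng.random() for _ in VOCAB]
            s = sum(w)
            probs[i] = {v: wi / s for v, wi in zip(VOCAB, w)}
    return probs

def unmask_branch(seq, probs, width):
    """One parent -> up to `width` children, each revealing a different
    (position, token) pair, so branches differ in content AND reveal order."""
    candidates = sorted(
        ((p, i, tok) for i, dist in probs.items() for tok, p in dist.items()),
        reverse=True)
    children = []
    for p, i, tok in candidates[:width]:
        child = list(seq)
        child[i] = tok
        children.append(child)
    return children

def resubstitute_score(seq, reward_fn):
    """Fill remaining masks with argmax predictions (a deterministic,
    low-variance proxy completion), then apply the reward."""
    probs = toy_model(seq)
    proxy = [max(probs[i], key=probs[i].get) if tok == MASK else tok
             for i, tok in enumerate(seq)]
    return reward_fn(proxy)

def tree_search(seq, reward_fn, width=4, keep=2):
    frontier = [list(seq)]
    while any(MASK in s for s in frontier):
        children = []
        for parent in frontier:
            probs = toy_model(parent)          # one model call per parent
            children.extend(unmask_branch(parent, probs, width))
        children.sort(key=lambda c: -resubstitute_score(c, reward_fn))
        frontier = children[:keep]             # prune to the best `keep`
    return frontier[0]

best = tree_search(["the", "movie", "was", MASK, MASK],
                   reward_fn=lambda s: sum(t in ("good", "great") for t in s))
```

Note that each loop iteration reveals exactly one token per child, so the number of masks strictly decreases and the search terminates; the real method's first-hitting unmasking and parallel-unmasking optimizations are omitted here for brevity.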

📝 Abstract
Tree search has recently emerged as a powerful framework for aligning generative models with task-specific rewards at test time. Applying tree search to Masked Diffusion Language Models, however, introduces two key challenges: (i) parallel unmasking yields highly correlated branches, limiting exploration, and (ii) reward evaluation via sampled completions produces high-variance estimates, making pruning unstable. We propose TReASURe, a tree-search test-time alignment method that addresses these issues. It introduces (i) UnmaskBranch, a branching strategy based on first-hitting unmasking that diversifies both token content and reveal order with a single model call per parent node, and (ii) ResubstituteScore, a pruning rule that uses deterministic resubstitution to score partially masked sequences with low-variance proxy completions. Theoretically, we quantify branching efficiency gains in NFEs (number of function evaluations), show that the scoring rule approximates the true reward with error bounded by predictive uncertainty, and prove improvements with larger tree widths. Empirically, TReASURe achieves state-of-the-art results on perplexity, linguistic acceptability, and control of sentiment and toxicity, outperforming prior methods under matched compute budgets, with especially strong gains in low-NFE regimes.
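The abstract's contrast between sampled completions and deterministic resubstitution can be made concrete with a toy comparison (all names and distributions below are hypothetical illustrations, not the paper's code): scoring a node by sampling a completion gives a different reward on each evaluation, while filling the mask with the argmax prediction gives one fixed proxy score per node.

```python
import random

MASK = "<mask>"

# Hypothetical predictive distribution the model assigns to the masked slot.
DIST = {"good": 0.6, "bad": 0.4}

def reward(seq):
    # Toy reward: count of positive tokens.
    return sum(tok == "good" for tok in seq)

prefix = ["the", "movie", "was"]

# (i) Sampling-based scoring: each call draws a different completion,
# so repeated evaluations of the SAME node yield different rewards.
rng = random.Random(0)
toks, ps = zip(*DIST.items())
sampled_scores = [reward(prefix + [rng.choices(toks, ps)[0]]) for _ in range(5)]

# (ii) Resubstitution scoring: fill the mask with the argmax prediction,
# giving one deterministic (zero-variance) proxy score for the node.
proxy_score = reward(prefix + [max(DIST, key=DIST.get)])
```

Under this toy setup the sampled scores bounce between 0 and 1 across evaluations, while the resubstitution score is always the same value, which is what makes pruning decisions stable.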
Problem

Research questions and friction points this paper is trying to address.

- Addresses correlated branching in masked diffusion language models
- Reduces high-variance reward estimates during tree-search pruning
- Improves alignment efficiency for text generation tasks
Innovation

Methods, ideas, or system contributions that make the work stand out.

- UnmaskBranch diversifies both token content and reveal order
- ResubstituteScore uses deterministic resubstitution for low-variance scoring
- Tree search improves alignment with task-specific rewards at test time
Zichao Yu, University of Science and Technology of China
Ming Li, Fudan University
Wenyi Zhang, University of Science and Technology of China
Weiguo Gao, Beijing University of Posts and Telecommunications
natural language processing