wd1: Weighted Policy Optimization for Reasoning in Diffusion Language Models

📅 2025-07-07
📈 Citations: 0
Influential: 0
🤖 AI Summary
Diffusion-based large language models (dLLMs) suffer from high computational overhead and accumulated bias in reinforcement learning (RL)-driven reasoning optimization, primarily because the intractable likelihood forces repeated approximations of the current, old, and reference policy likelihoods, introducing importance-sampling denominator errors and multi-stage approximation bias. Method: The authors propose wd1, a weighted policy optimization approach that reformulates the RL objective as weighted likelihood maximization. wd1 requires only a single approximation, of the current policy likelihood, eliminating dependence on importance sampling and on supervised fine-tuning, and follows an R1-Zero-like training paradigm for end-to-end reasoning optimization. Contribution/Results: On major reasoning benchmarks, wd1 achieves up to a 16% absolute accuracy gain over existing RL methods for dLLMs, while reducing the number of function evaluations (NFEs) per gradient step and improving both training efficiency and stability.

📝 Abstract
Improving the reasoning capabilities of diffusion-based large language models (dLLMs) through reinforcement learning (RL) remains an open problem. The intractability of the dLLM likelihood function necessitates approximating the current, old, and reference policy likelihoods at each policy optimization step. This reliance introduces additional computational overhead and leads to potentially large bias -- particularly when approximation errors occur in the denominator of the policy ratios used for importance sampling. To mitigate these issues, we introduce $\mathtt{wd1}$, a novel policy optimization approach that reformulates the objective as a weighted likelihood, requiring only a single approximation for the current parametrized policy likelihood. Experiments on widely used reasoning benchmarks demonstrate that $\mathtt{wd1}$, without supervised fine-tuning (SFT) or any supervised data, outperforms existing RL methods for dLLMs, achieving up to 16% higher accuracy. $\mathtt{wd1}$ delivers additional computational gains, including reduced training time and fewer function evaluations (NFEs) per gradient step. These findings, combined with the simplicity of the method's implementation and its R1-Zero-like training (no SFT), position $\mathtt{wd1}$ as a more effective and efficient method for applying RL to dLLM reasoning.
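The weighted-likelihood reformulation the abstract describes can be sketched in a few lines: instead of a PPO-style ratio of current, old, and reference policy likelihoods, the loss weights a single approximate log-likelihood of each sampled completion by a reward-derived weight. This is an illustrative sketch only; the softmax weighting, the `beta` temperature, and the group-relative baseline are assumptions for exposition, not the paper's exact formulation.

```python
import numpy as np

def weighted_likelihood_loss(logp_current, rewards, beta=1.0):
    """Sketch of a wd1-style weighted likelihood objective.

    logp_current: approximate log-likelihoods of sampled completions
        under the *current* policy only -- the single approximation
        the abstract refers to (no old/reference policy ratios).
    rewards: scalar rewards for the same completions.
    The softmax-over-advantages weighting here is a hypothetical
    choice, used to upweight above-average completions.
    """
    advantages = rewards - rewards.mean()   # group-relative baseline
    weights = np.exp(beta * advantages)
    weights /= weights.sum()                # normalized weights
    # Maximizing weighted likelihood == minimizing the negative
    # weighted log-likelihood; no importance-sampling denominator.
    return -(weights * logp_current).sum()
```

Because only `logp_current` appears, a single likelihood approximation per step suffices, which is the source of the NFE and training-time savings the abstract claims.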
Problem

Research questions and friction points this paper is trying to address.

Enhancing reasoning in diffusion language models via reinforcement learning
Reducing computational overhead in policy likelihood approximation
Improving accuracy and efficiency in RL for dLLMs
Innovation

Methods, ideas, or system contributions that make the work stand out.

Weighted likelihood reformulation for policy optimization
Single approximation for current policy likelihood
Reduced training time and function evaluations