🤖 AI Summary
Discrete diffusion models often suffer from reward degradation and unstable convergence during reward-guided fine-tuning due to suboptimal trajectory sampling. To address this, we propose a trajectory-aware tree search fine-tuning framework. Instead of relying on rollouts from the current policy, our method employs Monte Carlo Tree Search (MCTS) to actively construct a high-reward trajectory replay buffer and updates the discrete diffusion model parameters via a stochastic optimal control objective. Crucially, it integrates deterministic tree search with stochastic generative modeling, enabling high-quality, stable trajectory-guided fine-tuning. Evaluated on single- and multi-objective biomolecular sequence generation tasks, our approach achieves a +12.7% improvement in final reward, accelerates convergence by reducing required iterations by 38%, and enhances both functional validity and structural diversity of generated sequences.
📝 Abstract
Reinforcement learning with stochastic optimal control offers a promising framework for diffusion fine-tuning, where a pre-trained diffusion model is optimized to generate paths that lead to a reward-tilted distribution. While these approaches enable optimization without access to explicit samples from the optimal distribution, they require training on rollouts under the current fine-tuned model, making them susceptible to reinforcing suboptimal trajectories that yield poor rewards. To overcome this challenge, we introduce TRee Search Guided TRajectory-Aware Fine-Tuning for Discrete Diffusion (TR2-D2), a novel framework that optimizes reward-guided discrete diffusion trajectories with tree search to construct replay buffers for trajectory-aware fine-tuning. These buffers are generated using Monte Carlo Tree Search (MCTS) and subsequently used to fine-tune a pre-trained discrete diffusion model under a stochastic optimal control objective. We validate our framework on single- and multi-objective fine-tuning of biological sequence diffusion models, highlighting the overall effectiveness of TR2-D2 for reliable reward-guided fine-tuning in discrete sequence generation.
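The search-then-fine-tune loop described above can be sketched on a toy problem. Everything in the sketch is illustrative and not from the paper: a binary vocabulary, a hypothetical match-fraction reward against a fixed `TARGET`, a PUCT-style tree search standing in for the paper's MCTS over diffusion trajectories, and a reward-weighted frequency update standing in for the stochastic optimal control objective and the actual discrete diffusion model.

```python
# Toy sketch of a search -> replay buffer -> fine-tune loop in the spirit of TR2-D2.
# All quantities here are illustrative stand-ins, not the paper's actual method.
import math
from collections import defaultdict

VOCAB = [0, 1]          # toy binary "sequence" alphabet
SEQ_LEN = 4
TARGET = [1, 0, 1, 1]   # hypothetical reward: fraction of positions matching this target

def reward(seq):
    return sum(int(a == b) for a, b in zip(seq, TARGET)) / SEQ_LEN

class Node:
    """One node per partial sequence (prefix) in the search tree."""
    def __init__(self, prefix):
        self.prefix = prefix
        self.children = {}   # token -> Node
        self.visits = 0
        self.value = 0.0     # sum of rewards backed up through this node

def mcts_best_trajectory(policy, n_sim=200, c=1.4):
    """PUCT-style search: build sequences token by token, using the current
    policy as a prior, and return the highest-reward sequence seen."""
    root = Node([])
    best_seq, best_r = None, -1.0
    for _ in range(n_sim):
        node, path = root, [root]
        while len(node.prefix) < SEQ_LEN:
            pos = len(node.prefix)
            for tok in VOCAB:
                node.children.setdefault(tok, Node(node.prefix + [tok]))
            def score(tok):
                ch = node.children[tok]
                if ch.visits == 0:
                    return float("inf")  # expand unvisited children first
                q = ch.value / ch.visits
                return q + c * policy[pos][tok] * math.sqrt(node.visits) / (1 + ch.visits)
            node = node.children[max(VOCAB, key=score)]
            path.append(node)
        r = reward(node.prefix)
        if r > best_r:
            best_seq, best_r = node.prefix, r
        for n in path:       # back up the reward along the visited path
            n.visits += 1
            n.value += r
    return best_seq, best_r

def fine_tune(policy, buffer, lr=0.5):
    """Pull per-position token probabilities toward reward-weighted buffer
    frequencies -- a crude stand-in for the control-objective update."""
    for pos in range(SEQ_LEN):
        counts = defaultdict(float)
        for seq, r in buffer:
            counts[seq[pos]] += r
        total = sum(counts.values()) or 1.0
        for tok in VOCAB:
            policy[pos][tok] = (1 - lr) * policy[pos][tok] + lr * counts[tok] / total

# Uniform initial "pre-trained" policy, refined over a few search rounds.
policy = [{t: 1.0 / len(VOCAB) for t in VOCAB} for _ in range(SEQ_LEN)]
buffer = []
for _ in range(3):                       # search -> buffer -> fine-tune
    seq, r = mcts_best_trajectory(policy)
    buffer.append((seq, r))
    fine_tune(policy, buffer)
```

The key design point mirrored here is that fine-tuning consumes trajectories selected by the search, rather than rollouts from the current policy, so low-reward trajectories never enter the update.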