Diffusion Tree Sampling: Scalable inference-time alignment of diffusion models

📅 2025-06-25
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses how to adapt a pretrained diffusion model to new objectives at inference time without retraining. It formulates inference-time alignment as a tree search problem and, inspired by Monte Carlo Tree Search (MCTS), proposes Diffusion Tree Sampling (DTS) and its greedy variant, Diffusion Tree Search (DTS$^\star$). The method stores past denoising trajectories in a tree and reuses them, backpropagates terminal rewards through the diffusion chain, and iteratively refines value estimates. DTS produces asymptotically exact samples from the reward-aligned target distribution in the limit of infinite rollouts. Both variants are anytime algorithms: they can be stopped after any number of rollouts and turn additional compute into steadily better samples. On MNIST and CIFAR-10 class-conditional generation, DTS matches the FID of the best-performing baseline with up to 10× less compute. In text-to-image generation and language completion, DTS$^\star$ matches best-of-N with up to 5× less compute. The work connects diffusion inference with scalable sequential decision-making, bridging generative modeling and principled planning frameworks.
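The phrases "reward-aligned target," "backpropagates terminal rewards," and "refines value estimates" can be made concrete with a standard KL-regularized formulation. This is a sketch under stated assumptions, not necessarily the paper's notation. It assumes a pretrained reverse kernel $p_\theta(x_{t-1}\mid x_t)$, a terminal reward $r(x_0)$, and a temperature $\lambda$. Here $V_t$ is a soft value function, and $\hat V_t$ is its tree estimate built from $K$ sampled children.

```latex
\begin{aligned}
\pi^\star(x_0) &\propto p_\theta(x_0)\,\exp\!\big(r(x_0)/\lambda\big) \\
V_0(x_0) &= r(x_0) \\
V_t(x_t) &= \lambda \log \mathbb{E}_{x_{t-1}\sim p_\theta(\cdot\mid x_t)}\big[\exp\big(V_{t-1}(x_{t-1})/\lambda\big)\big] \\
\pi^\star(x_{t-1}\mid x_t) &= p_\theta(x_{t-1}\mid x_t)\,\exp\!\big((V_{t-1}(x_{t-1}) - V_t(x_t))/\lambda\big) \\
\hat V_t(x_t) &= \lambda \log \frac{1}{K}\sum_{k=1}^{K}\exp\big(\hat V_{t-1}(x_{t-1}^{(k)})/\lambda\big), \qquad x_{t-1}^{(k)} \sim p_\theta(\cdot\mid x_t)
\end{aligned}
```

Under this assumed formulation, the tilted transition is normalized, because the expectation of $\exp(V_{t-1}/\lambda)$ under $p_\theta$ equals $\exp(V_t/\lambda)$. Since each child in a tree is already drawn from $p_\theta$, sampling a child with probability proportional to $\exp(\hat V_{t-1}/\lambda)$ approximates $\pi^\star(x_{t-1}\mid x_t)$. Each new rollout adds a child estimate and sharpens $\hat V_t$.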

📝 Abstract
Adapting a pretrained diffusion model to new objectives at inference time remains an open problem in generative modeling. Existing steering methods suffer from inaccurate value estimation, especially at high noise levels, which biases guidance. Moreover, information from past runs is not reused to improve sample quality, resulting in inefficient use of compute. Inspired by the success of Monte Carlo Tree Search, we address these limitations by casting inference-time alignment as a search problem that reuses past computations. We introduce a tree-based approach that samples from the reward-aligned target density by propagating terminal rewards back through the diffusion chain and iteratively refining value estimates with each additional generation. Our proposed method, Diffusion Tree Sampling (DTS), produces asymptotically exact samples from the target distribution in the limit of infinite rollouts, and its greedy variant, Diffusion Tree Search (DTS$^\star$), performs a global search for high reward samples. On MNIST and CIFAR-10 class-conditional generation, DTS matches the FID of the best-performing baseline with up to $10\times$ less compute. In text-to-image generation and language completion tasks, DTS$^\star$ effectively searches for high reward samples that match best-of-N with up to $5\times$ less compute. By reusing information from previous generations, we get an anytime algorithm that turns additional compute into steadily better samples, providing a scalable approach for inference-time alignment of diffusion models.
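The mechanism described in the abstract can be illustrated with a small Python sketch. That mechanism is to keep past denoising paths in a tree, propagate terminal rewards back up the chain, and refine value estimates with each rollout. The sketch is not the authors' implementation. The 1D kernel `reverse_step`, the reward, the temperature `LAM`, the branching cap `MAX_CHILDREN`, and all other names are hypothetical. The sampling mode picks children with Boltzmann weights over their values, and the greedy mode takes the argmax. These serve only as stand-ins for DTS and DTS$^\star$. Because branching is capped, the sketch is not asymptotically exact. It only shows how stored values are reused across rollouts.

```python
import math
import random
from dataclasses import dataclass, field

# Hypothetical toy setup (assumptions, not taken from the paper).
T = 10            # number of denoising steps
LAM = 0.5         # temperature of the tilted target p(x0) * exp(r(x0) / LAM)
MAX_CHILDREN = 3  # fixed branching cap per node (makes this sketch non-exact)


def reverse_step(x: float, rng: random.Random) -> float:
    """Hypothetical stand-in for the pretrained kernel p_theta(x_{t-1} | x_t)."""
    return 0.9 * x + 0.3 * rng.gauss(0.0, 1.0)


def reward(x0: float) -> float:
    """Hypothetical terminal reward that prefers samples near 2.0."""
    return -((x0 - 2.0) ** 2)


@dataclass
class Node:
    x: float                 # noisy state x_t stored at this node
    t: int                   # diffusion timestep of this node
    value: float = 0.0       # soft value estimate V_t(x_t)
    children: list = field(default_factory=list)


def soft_backup(node: Node) -> None:
    """Soft Bellman backup: V = LAM * log(mean_k exp(V_k / LAM)), computed stably."""
    vals = [c.value / LAM for c in node.children]
    m = max(vals)
    mean_exp = sum(math.exp(v - m) for v in vals) / len(vals)
    node.value = LAM * (m + math.log(mean_exp))


def select(node: Node, rng: random.Random, greedy: bool) -> Node:
    """Argmax over child values (greedy) or sampling with weights exp(V / LAM)."""
    if greedy:
        return max(node.children, key=lambda c: c.value)
    vals = [c.value / LAM for c in node.children]
    m = max(vals)
    weights = [math.exp(v - m) for v in vals]
    return rng.choices(node.children, weights=weights, k=1)[0]


def rollout(root: Node, rng: random.Random, greedy: bool) -> float:
    """Descend from x_T to x_0, expanding or reusing children, then back up values."""
    path = [root]
    node = root
    while node.t > 0:
        if len(node.children) < MAX_CHILDREN:
            # Expand: one new call to the reverse kernel.
            child = Node(reverse_step(node.x, rng), node.t - 1)
            node.children.append(child)
        else:
            # Reuse: pick among stored children using their value estimates.
            child = select(node, rng, greedy)
        path.append(child)
        node = child
    node.value = reward(node.x)      # leaf value is the terminal reward
    for n in reversed(path[:-1]):    # propagate the reward back toward the root
        soft_backup(n)
    return node.x


def tree_sample(num_rollouts: int, greedy: bool = False, seed: int = 0) -> list:
    """Run rollouts on one shared tree and return the terminal sample of each."""
    rng = random.Random(seed)
    root = Node(rng.gauss(0.0, 1.0), T)
    return [rollout(root, rng, greedy) for _ in range(num_rollouts)]
```

For example, `tree_sample(200)` returns 200 terminal samples drawn from one shared tree. `tree_sample(200, greedy=True)` instead concentrates rollouts on the highest-valued branches. Each rollout reuses the nodes and value estimates left by earlier ones, so later rollouts are guided by rewards already observed.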
Problem

Research questions and friction points this paper is trying to address.

Adapting pretrained diffusion models to new objectives at inference time
Improving sample quality by reusing past generation information
Reducing computational costs in reward-aligned diffusion sampling
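The abstract benchmarks DTS$^\star$ against best-of-N. Best-of-N runs N independent denoising chains and keeps the highest-reward sample, so its cost grows as N times the number of denoising steps. The minimal Python sketch below makes that cost explicit by counting reverse-kernel calls. It uses a hypothetical toy kernel `reverse_step` and a hypothetical reward, both assumptions rather than the paper's setup.

```python
import random

T = 10  # hypothetical number of denoising steps


def reverse_step(x: float, rng: random.Random) -> float:
    """Hypothetical stand-in for the pretrained reverse kernel."""
    return 0.9 * x + 0.3 * rng.gauss(0.0, 1.0)


def reward(x0: float) -> float:
    """Hypothetical terminal reward that prefers samples near 2.0."""
    return -((x0 - 2.0) ** 2)


def best_of_n(n: int, seed: int = 0) -> tuple:
    """Run n independent chains; return (best sample, its reward, kernel calls)."""
    rng = random.Random(seed)
    calls = 0
    best_x, best_r = None, float("-inf")
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)      # fresh x_T for every chain
        for _ in range(T):
            x = reverse_step(x, rng)
            calls += 1
        r = reward(x)
        if r > best_r:               # keep only the highest-reward sample
            best_x, best_r = x, r
    return best_x, best_r, calls
```

With `T = 10`, `best_of_n(8)` makes 80 kernel calls. Every chain starts from scratch, and nothing learned from earlier chains is reused.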
Innovation

Methods, ideas, or system contributions that make the work stand out.

Tree-based approach for reward-aligned sampling
Reuses past computations to improve efficiency
Global search for high reward samples