Adding Conditional Control to Diffusion Models with Reinforcement Learning

📅 2024-06-17
🏛️ arXiv.org
🤖 AI Summary
This work addresses the problem of adding new conditional controls to pretrained diffusion models. We propose CTRL, a framework that formulates conditional generation as a reinforcement learning (RL) problem. CTRL fine-tunes the pretrained model with KL-regularized RL using an offline dataset of inputs and labels. The reward is given by a classifier learned from that dataset, and a KL divergence against the pretrained model keeps the fine-tuned policy close to it. To our knowledge, this is the first approach to cast conditional diffusion generation as RL policy learning in this way. Unlike classifier guidance, it does not require classifiers on intermediate (noisy) states. Compared to classifier-free guidance, it improves sample efficiency and simplifies dataset construction. CTRL yields soft-optimal policies, and we formally show that they enable sampling from the desired conditional distribution. Experiments on image generation show that CTRL outperforms classifier guidance and classifier-free guidance (CFG) with fewer labeled examples, achieving better trade-offs between control fidelity and sample diversity.
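
The link between a KL-regularized classifier reward and the conditional distribution can be sketched with a standard identity for KL-regularized objectives. The following assumptions and notation are ours, for illustration only, and may differ from the paper's:
- the reward is r(x) = log p(y | x), taken from the offline classifier;
- α is a regularization weight;
- p_pre matches the data distribution;
- the classifier is exact.

```latex
\begin{aligned}
&\max_{q}\; \mathbb{E}_{x \sim q}\big[r(x)\big] - \alpha\,\mathrm{KL}\big(q \,\|\, p_{\mathrm{pre}}\big),
  \qquad r(x) = \log p(y \mid x), \\
&q^{\star}(x) = \frac{p_{\mathrm{pre}}(x)\,\exp\big(r(x)/\alpha\big)}
                     {\int p_{\mathrm{pre}}(x')\,\exp\big(r(x')/\alpha\big)\,dx'}, \\
&\alpha = 1:\quad q^{\star}(x) = \frac{p_{\mathrm{pre}}(x)\,p(y \mid x)}
                     {\int p_{\mathrm{pre}}(x')\,p(y \mid x')\,dx'} = p(x \mid y).
\end{aligned}
```

Under these assumptions, α = 1 makes the soft-optimal policy exactly the conditional distribution p(x | y). Other values of α trade control strength against staying close to the pretrained model.
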

📝 Abstract
Diffusion models are powerful generative models that allow for precise control over the characteristics of the generated samples. While diffusion models trained on large datasets have achieved success, there is often a need to introduce additional controls during downstream fine-tuning, treating these large models as pre-trained diffusion models. This work presents a novel method based on reinforcement learning (RL) to add such controls using an offline dataset comprising inputs and labels. We formulate this task as an RL problem in which a classifier learned from the offline dataset and the KL divergence against the pre-trained model serve as the reward functions. Our method, CTRL (Conditioning pre-Trained diffusion models with Reinforcement Learning), produces soft-optimal policies that maximize the above-mentioned reward functions. We formally demonstrate that our method enables sampling from the conditional distribution with additional controls during inference. Our RL-based approach offers several advantages over existing methods. Compared to classifier-free guidance, it improves sample efficiency and can greatly simplify dataset construction by leveraging conditional independence between the inputs and the additional controls. Additionally, unlike classifier guidance, it eliminates the need to train classifiers that map intermediate (noisy) states to the additional controls. The code is available at https://github.com/zhaoyl18/CTRL.
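
The dataset-construction argument can be illustrated with a toy model. This relies on one reading of the conditional-independence assumption: if the additional control y is independent of the model's existing condition c given the sample x, then p(x | c, y) ∝ p_pre(x | c) p(y | x). Under that reading, a classifier fitted on (x, y) pairs alone suffices, and no (x, c, y) triplets are needed. The sketch below is a discrete stand-in for a diffusion model, not the authors' code. Every name and probability in it is hypothetical.

```python
import random

XS = [0, 1, 2, 3]  # hypothetical discrete sample space standing in for images

# Hypothetical pretrained conditional model p_pre(x | c) for an existing condition c.
P_PRE = {"a": [0.4, 0.3, 0.2, 0.1], "b": [0.1, 0.2, 0.3, 0.4]}

# Hypothetical ground-truth labeler p(y = 1 | x); the method only sees samples from it.
TRUE_P_Y1 = [0.9, 0.6, 0.3, 0.1]


def make_offline_dataset(n, seed=0):
    """Draw offline (x, y) pairs; the existing condition c is never recorded."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.choice(XS)
        y = 1 if rng.random() < TRUE_P_Y1[x] else 0
        data.append((x, y))
    return data


def fit_classifier(data):
    """Count-based estimate of p(y = 1 | x) from the offline pairs."""
    ones = [0] * len(XS)
    totals = [0] * len(XS)
    for x, y in data:
        totals[x] += 1
        ones[x] += y
    return [ones[x] / totals[x] for x in XS]


def controlled(p_pre_c, p_y1, y):
    """p(x | c, y) ∝ p_pre(x | c) * p(y | x), assuming y is independent of c given x."""
    weights = [p_pre_c[x] * (p_y1[x] if y == 1 else 1.0 - p_y1[x]) for x in XS]
    total = sum(weights)
    return [w / total for w in weights]


p_hat = fit_classifier(make_offline_dataset(20_000))
for c in ("a", "b"):
    estimated = controlled(P_PRE[c], p_hat, y=1)
    exact = controlled(P_PRE[c], TRUE_P_Y1, y=1)
    print(c, [round(v, 3) for v in estimated], [round(v, 3) for v in exact])
```

The same fitted classifier serves every value of c. Only the pretrained conditional changes, which is why pairs without c are enough in this toy setting.
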
Problem

Research questions and friction points this paper is trying to address.

Adding new controls to pretrained diffusion models
Using reinforcement learning for fine-tuning
Improving sample efficiency and simplifying dataset construction
Innovation

Methods, ideas, or system contributions that make the work stand out.

Reinforcement learning adds conditional control to pretrained diffusion models
Classifier reward combined with a KL divergence against the pretrained model
Soft-optimal policies that maximize the KL-regularized reward
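
The soft-optimality property in the last item can be checked numerically on a finite toy problem. The sketch makes three assumptions:
- the reward is r(x) = log p(y | x);
- the KL weight is named alpha (a hypothetical name; the paper's notation may differ);
- outcomes are enumerated directly instead of training a diffusion policy.

It therefore illustrates the target of the optimization, not CTRL's training procedure. All probabilities are made up.

```python
import math
import random

# Hypothetical toy: 5 outcomes, a pretrained distribution, and a classifier p(y | x).
P_PRE = [0.30, 0.25, 0.20, 0.15, 0.10]
P_Y_GIVEN_X = [0.05, 0.20, 0.50, 0.80, 0.95]
REWARD = [math.log(p) for p in P_Y_GIVEN_X]  # r(x) = log p(y | x)


def objective(q, alpha=1.0):
    """Regularized objective E_q[r(x)] - alpha * KL(q || p_pre)."""
    expected_reward = sum(qi * ri for qi, ri in zip(q, REWARD))
    kl = sum(qi * math.log(qi / pi) for qi, pi in zip(q, P_PRE) if qi > 0)
    return expected_reward - alpha * kl


def soft_optimal(alpha=1.0):
    """Closed-form maximizer q*(x) ∝ p_pre(x) * exp(r(x) / alpha)."""
    weights = [pi * math.exp(ri / alpha) for pi, ri in zip(P_PRE, REWARD)]
    total = sum(weights)
    return [w / total for w in weights]


def random_distribution(rng, n):
    """A random competitor distribution on n outcomes."""
    vals = [rng.random() + 1e-9 for _ in range(n)]
    total = sum(vals)
    return [v / total for v in vals]


q_star = soft_optimal()

# With alpha = 1 and r = log p(y | x), q* coincides with the Bayes posterior p(x | y).
evidence = sum(p * l for p, l in zip(P_PRE, P_Y_GIVEN_X))
posterior = [p * l / evidence for p, l in zip(P_PRE, P_Y_GIVEN_X)]

# No random competitor should score higher on the regularized objective.
rng = random.Random(0)
best_other = max(objective(random_distribution(rng, len(P_PRE))) for _ in range(1000))
print(objective(q_star), best_other)
```

Lowering alpha sharpens q* toward high-reward outcomes. Raising it keeps q* closer to the pretrained distribution.
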