🤖 AI Summary
This work addresses the high computational cost and instability of conventional approaches to aligning diffusion models with human preferences, which typically require multiple training runs to search for an optimal KL regularization strength—often leading to under-alignment or reward hacking when improperly tuned. The authors propose a denoising-time realignment mechanism that, after a single alignment training run, dynamically adjusts the regularization strength during sampling to emulate the effects of different strengths without additional training. Extending decoding-time realignment from language models to diffusion models for the first time, the method derives a closed-form update rule based on geometric mixing of reference and aligned posterior distributions, enabling real-time control in the continuous latent space via a single tunable parameter λ. Experiments demonstrate that this approach accurately approximates the performance of fully retrained models across multiple text-to-image alignment and image quality metrics while substantially reducing computational overhead.
📝 Abstract
Recent advances align diffusion models with human preferences to increase aesthetic appeal and mitigate artifacts and biases. Such methods aim to maximize a conditional output distribution aligned with higher rewards while not drifting far from a pretrained prior, commonly enforced through KL (Kullback–Leibler) regularization. A central issue remains: how does one choose the right regularization strength? Too high a strength leads to limited alignment, while too low a strength leads to "reward hacking"; choosing the correct strength is therefore highly non-trivial. Existing approaches sweep over this hyperparameter by aligning a pretrained model at multiple regularization strengths and then choosing the best one, which is prohibitively expensive. We introduce DeRaDiff, a denoising-time realignment procedure that, after aligning a pretrained model once, modulates the regularization strength during sampling to emulate models trained at other regularization strengths, without any additional training or finetuning. Extending decoding-time realignment from language models to diffusion models, DeRaDiff operates over iterative predictions of continuous latents by replacing the reverse-step reference distribution with a geometric mixture of the aligned and reference posteriors, which yields a closed-form update under common schedulers and a single tunable parameter, λ, for on-the-fly control. Our experiments show that, across multiple text-image alignment and image-quality metrics, our method consistently provides a strong approximation to models aligned entirely from scratch at different regularization strengths. Our method thus offers an efficient way to search for the optimal strength, eliminating the need for expensive alignment sweeps and substantially reducing computational cost.
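To see why the geometric mixture admits a closed-form update: if both reverse-step posteriors are Gaussian with a shared, scheduler-determined covariance, then the geometric mixture p_aligned^λ · p_ref^(1−λ) is again Gaussian with the same covariance and a mean that is simply the convex combination of the two means. A minimal NumPy sketch of one such mixed denoising step (all names and values here are hypothetical illustrations, not the paper's implementation; the exact update depends on the scheduler used):

```python
import numpy as np

def geometric_mixture_mean(mu_aligned, mu_ref, lam):
    """Mean of the geometric mixture p_aligned^lam * p_ref^(1-lam)
    of two Gaussians sharing one covariance: a convex combination
    of the two means, weighted by lam."""
    return lam * np.asarray(mu_aligned) + (1.0 - lam) * np.asarray(mu_ref)

# One sketched reverse (denoising) step: each model predicts a posterior
# mean for the previous latent x_{t-1}; realignment mixes the two means
# before sampling with the scheduler's noise scale.
rng = np.random.default_rng(0)
mu_aligned = np.array([0.8, -0.2])  # aligned model's predicted mean (toy values)
mu_ref = np.array([0.2, 0.4])       # reference (pretrained) model's mean (toy values)
sigma_t = 0.1                        # scheduler noise scale at this step (toy value)

lam = 0.5  # the single tunable knob: lam=1 recovers the aligned model, lam=0 the reference
mu_mix = geometric_mixture_mean(mu_aligned, mu_ref, lam)
x_prev = mu_mix + sigma_t * rng.standard_normal(2)  # sample x_{t-1} from the mixed posterior
```

Varying λ at sampling time therefore interpolates between the reference and aligned posteriors at every denoising step, emulating models trained at intermediate regularization strengths without retraining.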