🤖 AI Summary
GRPO faces two key bottlenecks in diffusion model alignment: high computational cost, driven by online policy rollouts and extensive SDE sampling, and training instability, caused by sparse rewards that produce high gradient variance. To address these, this paper proposes BranchGRPO, a branch sampling framework. Its core contributions are: (1) tree-structured branching with shared prefixes to minimize redundant computation; (2) a tree-based advantage estimator that improves the gradient signal-to-noise ratio; and (3) path-wise reward pruning and depth-adaptive pruning to suppress the propagation of low-quality trajectories. In addition, process-level dense rewards are introduced to mitigate reward sparsity. Evaluated on image and video generation alignment tasks, BranchGRPO achieves a 16% improvement in alignment score over strong baselines while reducing training time by 50%, demonstrating substantial gains in efficiency, stability, and performance.
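The branching-with-shared-prefixes and path-wise pruning ideas above can be sketched in a few lines. This is a minimal illustrative sketch, not the paper's implementation: the function names (`branch_rollout`, `prune_low_reward`), the recursion structure, and the stand-in `step_fn`/`reward_fn` callables are all assumptions made here for clarity.

```python
import random

def branch_rollout(state, depth, max_depth, branch_factor, step_fn, reward_fn):
    """Expand a sampling tree recursively. The computation from the root to
    `state` (the shared prefix) is performed once and reused by every
    descendant path, instead of once per independent rollout."""
    if depth == max_depth:
        return [([state], reward_fn(state))]  # leaf: (path, terminal reward)
    paths = []
    for _ in range(branch_factor):
        child = step_fn(state, depth)  # one stochastic (SDE-like) sampling step
        for tail, r in branch_rollout(child, depth + 1, max_depth,
                                      branch_factor, step_fn, reward_fn):
            paths.append(([state] + tail, r))
    return paths

def prune_low_reward(paths, keep_frac=0.5):
    """Path-wise reward pruning: keep only the top-rewarded fraction of
    trajectories so low-quality paths stop propagating gradient signal."""
    ranked = sorted(paths, key=lambda p: p[1], reverse=True)
    return ranked[: max(1, int(len(ranked) * keep_frac))]
```

With `branch_factor=2` and `max_depth=3`, a single root yields 8 leaf trajectories while the early sampling steps are computed only once each; pruning with `keep_frac=0.5` then retains the 4 highest-reward paths.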
📝 Abstract
Recent advances in aligning image and video generative models via GRPO have achieved remarkable gains in human preference alignment. However, these methods still face high computational costs from on-policy rollouts and excessive SDE sampling steps, as well as training instability due to sparse rewards. In this paper, we propose BranchGRPO, a novel method that introduces a branch sampling policy into the SDE sampling process. By sharing computation across common prefixes and pruning low-reward paths and redundant depths, BranchGRPO substantially lowers the per-update compute cost while maintaining or improving exploration diversity. This work makes three main contributions: (1) a branch sampling scheme that reduces rollout and training cost; (2) a tree-based advantage estimator incorporating dense process-level rewards; and (3) pruning strategies that exploit path and depth redundancy to accelerate convergence and boost performance. Experiments on image and video preference alignment show that BranchGRPO improves alignment scores by 16% over strong baselines while cutting training time by 50%.
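The tree-based advantage estimator pairs naturally with GRPO's group-relative normalization: rather than normalizing rewards once over a flat group of independent rollouts, rewards can be normalized among the sibling branches at each branch point of the tree. The sketch below shows one plausible form of this; the normalization choice (sibling mean and population std) is an assumption for illustration, not the paper's exact formulation.

```python
from statistics import mean, pstdev

def tree_advantages(sibling_rewards):
    """Group-relative advantage among the children of one branch point:
    each sibling's reward is centered on the sibling mean and scaled by
    the sibling standard deviation (GRPO-style normalization, applied
    per branch point instead of once per flat rollout group)."""
    mu = mean(sibling_rewards)
    sd = pstdev(sibling_rewards) or 1.0  # guard against zero-variance groups
    return [(r - mu) / sd for r in sibling_rewards]
```

Because siblings share their entire prefix, the normalized difference isolates the effect of the step where they diverged, which is one way dense process-level reward signal can sharpen the gradient signal-to-noise ratio.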