🤖 AI Summary
In Reinforcement Learning with Verifiable Rewards (RLVR), high variance in policy-gradient estimation leads to unstable training. To address this, we introduce Stein shrinkage into baseline design for the first time, proposing a cross-prompt reward-mean shrinkage estimator that requires no additional hyperparameters or computational overhead. Our method jointly leverages local (per-prompt) and global (batch-level) reward statistics to achieve improved reward centering within policy-gradient frameworks such as GRPO. We provide a theoretical proof that the estimator strictly reduces gradient variance. Empirically, our method yields significantly improved training stability over the empirical-mean baseline, particularly in low-sample generation settings, where it delivers consistently superior and more robust performance across diverse tasks and model scales.
📝 Abstract
Reinforcement Learning with Verifiable Rewards (RLVR) has emerged as a powerful paradigm for post-training large reasoning models (LRMs) using policy-gradient methods such as GRPO. To stabilize training, these methods typically center trajectory rewards by subtracting the empirical mean for each prompt. Statistically, this centering acts as a control variate (or baseline), reducing the variance of the policy-gradient estimator. In practice, the per-prompt mean reward is estimated by averaging over the generations sampled for each prompt in a batch. Drawing inspiration from Stein's paradox, we propose shrinkage estimators that combine per-prompt and across-prompt means to improve per-prompt mean estimation accuracy -- particularly in the low-generation regime typical of RLVR. Theoretically, we construct a shrinkage-based baseline that provably yields lower-variance policy-gradient estimators across algorithms. The proposed baseline serves as a drop-in replacement for existing per-prompt mean baselines, requiring no additional hyperparameters or computation. Empirically, shrinkage baselines consistently outperform standard empirical-mean baselines, leading to lower-variance gradient updates and improved training stability.
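The idea of shrinking per-prompt reward means toward the batch-level mean can be sketched as follows. The abstract does not give the paper's exact estimator, so this is an illustrative James-Stein-style (positive-part) shrinkage toward the grand mean, assuming a batch of prompts with equal generation counts and roughly homoscedastic reward noise; the function name and formula details are hypothetical.

```python
import numpy as np

def shrinkage_baselines(rewards):
    """Illustrative James-Stein-style shrinkage of per-prompt mean rewards
    toward the batch-level grand mean (not necessarily the paper's exact
    estimator).

    rewards: array of shape (num_prompts, num_generations)
    returns: per-prompt baselines of shape (num_prompts,)
    """
    k, n = rewards.shape
    prompt_means = rewards.mean(axis=1)   # local (per-prompt) statistic
    grand_mean = rewards.mean()           # global (batch-level) statistic
    # Pooled variance of each prompt-mean estimate (sigma^2 / n).
    sigma2 = rewards.var(axis=1, ddof=1).mean() / n
    ss = np.sum((prompt_means - grand_mean) ** 2)
    # Positive-part James-Stein shrinkage factor; sensible for k >= 4.
    shrink = max(0.0, 1.0 - (k - 3) * sigma2 / ss) if ss > 0 else 0.0
    # Convex combination of global and local estimates.
    return grand_mean + shrink * (prompt_means - grand_mean)

# Usage in a GRPO-style update: center rewards with the shrunk baseline
# instead of the raw per-prompt mean.
# advantages = rewards - shrinkage_baselines(rewards)[:, None]
```

Because the shrinkage factor lies in [0, 1], each baseline is pulled from the noisy per-prompt mean toward the better-estimated grand mean, with stronger shrinkage when the per-prompt means are tightly clustered relative to their sampling noise -- exactly the low-generation regime the abstract highlights.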