🤖 AI Summary
This work addresses performance limitations in reinforcement learning for large language models that stem from policy instability and biased KL divergence estimation. The authors propose two techniques: first, replacing the fixed anchor (reference) policy with an exponential moving average (EMA) anchor policy, analogous to a target network in deep Q-learning, to improve stability; second, a Top-k KL estimator that interpolates between exact and sampled KL and yields unbiased estimates of both the KL divergence and its gradient at any k. Combined with GRPO, the resulting method (EMA-PG) brings R1-distilled Qwen-1.5B to 53.9% on OlympiadBench, compared with 50.8% for GRPO, and, with the Qwen-3B base model, improves on GRPO by an average of 33.3% across seven search-augmented question-answering benchmarks, including an increase from 29.7% to 44.1% on HotpotQA.
📝 Abstract
Reinforcement Learning (RL) has enabled Large Language Models (LLMs) to acquire increasingly complex reasoning and agentic behaviors. In this work, we propose two simple techniques to improve policy gradient algorithms for LLMs. First, we replace the fixed anchor policy during RL with an Exponential Moving Average (EMA), similar to a target network in deep Q-learning. Second, we introduce the Top-k KL estimator, which allows for flexible interpolation between exact KL and sampled KL. We derive the stability conditions for using an EMA anchor; moreover, we show that our Top-k KL estimator yields both unbiased KL values and unbiased gradients at any k, while bringing the benefits of exact KL. When combined with GRPO, the two techniques (EMA-PG) lead to a significant performance boost. On math reasoning, EMA-PG allows R1-distilled Qwen-1.5B to reach 53.9% on OlympiadBench, compared to 50.8% with GRPO. On agentic RL domains, with Qwen-3B base, EMA-PG improves GRPO by an average of 33.3% across 7 datasets of Q&A with search engines, including 29.7% $\rightarrow$ 44.1% on HotpotQA and 27.4% $\rightarrow$ 40.1% on 2WikiMultiHopQA. Overall, we show that EMA-PG is a simple, principled, and powerful approach to scaling RL for LLMs. Code: https://github.com/LunjunZhang/ema-pg
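To make the two ideas in the abstract concrete, here is a minimal sketch in plain Python. It is not taken from the paper or its repository. The function names, the `decay` parameter, and the particular way the KL is split into an exact "head" over the k most likely tokens plus a single-sample "tail" term are all assumptions chosen for illustration. The sketch only demonstrates unbiasedness of the KL value on a toy distribution; it does not address the gradient, and the authors' estimator may be constructed differently.

```python
import math


def ema_update(anchor, policy, decay=0.99):
    """Hypothetical EMA anchor update: move each anchor parameter a small
    step toward the current policy parameter (decay is an assumed name)."""
    return [decay * a + (1.0 - decay) * p for a, p in zip(anchor, policy)]


def exact_kl(p, q):
    """Exact KL(p || q), summed over the full vocabulary."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def topk_kl_estimate(p, q, sampled_token, k):
    """Assumed Top-k style estimate of KL(p || q) for one token sampled from p.

    The k most likely tokens under p contribute their exact KL terms.
    The remaining tail is estimated from the sampled token: if it falls
    outside the top-k set, its log-ratio is added once.
    """
    order = sorted(range(len(p)), key=lambda i: p[i], reverse=True)
    top = set(order[:k])
    head = sum(p[i] * math.log(p[i] / q[i]) for i in top if p[i] > 0.0)
    if sampled_token in top:
        return head
    return head + math.log(p[sampled_token] / q[sampled_token])


def expected_estimate(p, q, k):
    """Expectation of the estimate when the token is drawn from p."""
    return sum(p[a] * topk_kl_estimate(p, q, a, k) for a in range(len(p)))
```

In this sketch, `k = 0` reduces to the usual single-sample log-ratio estimate and `k = len(p)` reduces to the exact sum, so `expected_estimate(p, q, k)` matches `exact_kl(p, q)` for every `k` in between. That is one simple way to read the "interpolation between exact KL and sampled KL" described in the abstract.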