Accelerating RL for LLM Reasoning with Optimal Advantage Regression

📅 2025-05-27
📈 Citations: 0 (Influential: 0)
🤖 AI Summary
Existing reinforcement learning (RL) fine-tuning methods for improving complex reasoning in large language models (LLMs), such as PPO and GRPO, rely on multiple generations per prompt or on auxiliary critic networks. This incurs high computational cost and GPU memory overhead. To address this, we propose A*-PO, a two-stage policy optimization framework. In the first stage, the optimal value function $V^*$ is estimated offline from samples drawn from a reference policy. In the second stage, the policy is updated on-policy with a least-squares advantage regression loss, using only a single generation per prompt. We theoretically show that the KL-regularized objective can be optimized without explicit exploration strategies. As a result, A*-PO needs neither a critic network nor repeated sampling during training. On mathematical reasoning benchmarks, it achieves performance competitive with state-of-the-art RL algorithms while reducing training time by up to 2× and peak GPU memory usage by over 30%, making RL-based LLM fine-tuning more practical and scalable.
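The identity that links the two stages can be written out under the standard KL-regularized RL objective. This is a sketch that assumes a scalar reward $r(x,y)$, a reference policy $\pi_{\mathrm{ref}}$, and a single KL coefficient $\beta > 0$. The notation is ours and may differ from the paper's.

```latex
\begin{align*}
&\max_{\pi}\; \mathbb{E}_{x}\Big[\mathbb{E}_{y\sim\pi(\cdot\mid x)}\big[r(x,y)\big]
  - \beta\,\mathrm{KL}\big(\pi(\cdot\mid x)\,\|\,\pi_{\mathrm{ref}}(\cdot\mid x)\big)\Big] \\
&\pi^*(y\mid x) = \frac{\pi_{\mathrm{ref}}(y\mid x)\,\exp\!\big(r(x,y)/\beta\big)}{Z(x)},
  \qquad Z(x) = \mathbb{E}_{y\sim\pi_{\mathrm{ref}}(\cdot\mid x)}\Big[\exp\!\big(r(x,y)/\beta\big)\Big] \\
&V^*(x) = \beta \log Z(x) \\
&\beta \log \frac{\pi^*(y\mid x)}{\pi_{\mathrm{ref}}(y\mid x)} = r(x,y) - V^*(x) \;=:\; A^*(x,y)
\end{align*}
```

$V^*$ depends only on $\pi_{\mathrm{ref}}$ and $r$, so it can be estimated once, offline, from reference samples. The last line then suggests regressing $\beta \log(\pi_\theta/\pi_{\mathrm{ref}})$ onto $r - \hat V^*$, which needs only one fresh response per prompt. Whether the paper uses the same $\beta$ in both stages is not stated in this summary.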

📝 Abstract
Reinforcement learning (RL) has emerged as a powerful tool for fine-tuning large language models (LLMs) to improve complex reasoning abilities. However, state-of-the-art policy optimization methods often suffer from high computational overhead and memory consumption, primarily due to the need for multiple generations per prompt and the reliance on critic networks or advantage estimates of the current policy. In this paper, we propose A*-PO, a novel two-stage policy optimization framework that directly approximates the optimal advantage function and enables efficient training of LLMs for reasoning tasks. In the first stage, we leverage offline sampling from a reference policy to estimate the optimal value function $V^*$, eliminating the need for costly online value estimation. In the second stage, we perform on-policy updates using a simple least-squares regression loss with only a single generation per prompt. Theoretically, we establish performance guarantees and prove that the KL-regularized RL objective can be optimized without requiring complex exploration strategies. Empirically, A*-PO achieves competitive performance across a wide range of mathematical reasoning benchmarks, while reducing training time by up to 2× and peak memory usage by over 30% compared to PPO, GRPO, and REBEL. Implementation of A*-PO can be found at https://github.com/ZhaolinGao/A-PO.
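The first stage can be illustrated with a Monte Carlo estimate of $V^*(x) = \beta \log \mathbb{E}_{y\sim\pi_{\mathrm{ref}}}[\exp(r(x,y)/\beta)]$. This is the closed-form optimal value under a KL-regularized objective with coefficient $\beta$, and using it here is our assumption rather than a detail taken from the released code. `estimate_v_star` is a hypothetical helper. It takes the rewards of $N$ responses sampled offline from $\pi_{\mathrm{ref}}$ for one prompt (for example, 0/1 correctness scores, also an assumption). It applies a log-sum-exp shift so that small $\beta$ does not overflow.

```python
import math
from typing import Sequence


def estimate_v_star(rewards: Sequence[float], beta: float) -> float:
    """Hypothetical Monte Carlo estimate of V*(x) = beta * log E_{y~pi_ref}[exp(r(x, y) / beta)].

    `rewards` holds r(x, y_i) for N responses y_i sampled offline from pi_ref
    for a single prompt x. Assumes the KL-regularized closed form for V*.
    """
    if not rewards:
        raise ValueError("need at least one reference sample")
    if beta <= 0:
        raise ValueError("beta must be positive")
    scaled = [r / beta for r in rewards]
    # Log-sum-exp trick: subtract the max before exponentiating to avoid overflow.
    m = max(scaled)
    log_mean_exp = m + math.log(sum(math.exp(s - m) for s in scaled) / len(scaled))
    return beta * log_mean_exp


# One prompt where 1 of 4 reference samples was judged correct.
v_hat = estimate_v_star([1.0, 0.0, 0.0, 0.0], beta=1.0)
```

As $\beta \to 0$ the estimate approaches the best sampled reward, and as $\beta \to \infty$ it approaches the sample mean. For finite $N$ the estimator is biased downward in expectation, because by Jensen's inequality the log of a sample mean underestimates the log of the true mean on average. Drawing more reference samples per prompt tightens the estimate.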
Problem

Research questions and friction points this paper is trying to address.

Reducing the computational overhead of RL fine-tuning for LLM reasoning
Eliminating the need for multiple generations per prompt
Lowering peak GPU memory usage during policy optimization

Innovation

Methods, ideas, or system contributions that make the work stand out.

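Two-stage policy optimization framework (A*-PO) for RL fine-tuning of LLMs
Offline sampling from a reference policy to estimate the optimal value function $V^*$
Single-generation on-policy updates via least-squares advantage regression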
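The second-stage update can be sketched as a least-squares loss. It regresses $\beta \log(\pi_\theta(y\mid x)/\pi_{\mathrm{ref}}(y\mid x))$ onto the estimated optimal advantage $r(x,y) - \hat V^*(x)$, with one on-policy response $y$ per prompt. This is an illustration under stated assumptions, not the released implementation. Sequence log-probabilities are passed in as plain floats. In practice they would be summed token log-probabilities from the policy and reference models, with gradients flowing only through the policy term. `RegressionSample`, `advantage_regression_loss`, and `loss_grad_wrt_logp_policy` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class RegressionSample:
    """One prompt x with a single on-policy response y (hypothetical container)."""
    logp_policy: float  # log pi_theta(y | x), summed over response tokens
    logp_ref: float     # log pi_ref(y | x), summed over response tokens
    reward: float       # r(x, y), e.g. a verifier score
    v_star: float       # stage-1 estimate of V*(x), computed offline


def _residuals(batch: Sequence[RegressionSample], beta: float) -> List[float]:
    # residual = prediction - target = beta * log-ratio - (r - V*)
    return [beta * (s.logp_policy - s.logp_ref) - (s.reward - s.v_star) for s in batch]


def advantage_regression_loss(batch: Sequence[RegressionSample], beta: float) -> float:
    """Mean squared error between beta * log(pi_theta / pi_ref) and r - V*."""
    if not batch:
        raise ValueError("batch must be non-empty")
    res = _residuals(batch, beta)
    return sum(e * e for e in res) / len(res)


def loss_grad_wrt_logp_policy(batch: Sequence[RegressionSample], beta: float) -> List[float]:
    """Per-sample d(loss)/d(logp_policy_i) = 2 * beta * residual_i / N."""
    if not batch:
        raise ValueError("batch must be non-empty")
    n = len(batch)
    return [2.0 * beta * e / n for e in _residuals(batch, beta)]
```

A negative gradient entry means a gradient-descent step raises $\log \pi_\theta(y\mid x)$. This happens when the current log-ratio undershoots the estimated advantage, for example on a correct response to a prompt whose $\hat V^*$ is low. No critic network or group of samples enters the loss, because the baseline $\hat V^*(x)$ was fixed in stage 1.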