On a few pitfalls in KL divergence gradient estimation for RL

📅 2025-06-11
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work identifies two fundamental misuses of KL divergence gradient estimation in large language model (LLM) training under reinforcement learning (RL): (1) directly differentiating noisy, sample-based KL estimates as if they were deterministic differentiable losses, yielding biased gradients; and (2) ignoring the sequential decision-making structure by computing gradients only at the final token, resulting in incomplete gradient signals. To address these issues, the authors systematically derive, for the first time, a strictly unbiased KL gradient formulation grounded in policy gradient theory and sequential probability modeling. They further propose a principled implementation paradigm aligned with standard RL training workflows. Empirical validation across tabular RL benchmarks and LLM fine-tuning demonstrates that erroneous gradient implementations induce policy degradation, whereas the corrected gradients substantially improve KL constraint fidelity, convergence stability, and final task performance.

📝 Abstract
We point out a few pitfalls in implementing gradient estimation for KL divergence in RL training for LLMs, as seen in a number of open-source projects and papers. The first major pitfall is differentiating through the KL estimate as a loss function to minimize KL divergence. We show that such implementations are generally incorrect and do not produce the desired KL gradient. Secondly, we show that some implementations do not account for the sequential nature of the estimation problem and produce a partial gradient at best. We demonstrate the impact of such issues with illustrative tabular and LLM experiments, and show the correct way to implement the KL gradient.
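The first pitfall can be made concrete on a single categorical distribution. The sketch below is a hypothetical tabular illustration (not code from the paper; all names such as `theta`, `q`, and `grad_naive` are invented here): treating the sampled estimate `mean(log pi(x) - log q(x))` as a differentiable loss yields a gradient whose expectation is exactly zero, while the score-function form recovers the analytic KL gradient.

```python
import numpy as np

# Hypothetical tabular illustration (not from the paper): policy
# pi = softmax(theta) over K actions, fixed reference distribution q.
def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

K = 4
rng = np.random.default_rng(0)
theta = rng.normal(size=K)
q = softmax(rng.normal(size=K))

pi = softmax(theta)
f = np.log(pi) - np.log(q)      # per-sample log-ratio: log pi(x) - log q(x)
kl = float((pi * f).sum())      # exact KL(pi || q)

# Analytic gradient w.r.t. the logits: dKL/dtheta_i = pi_i * (f_i - KL).
grad_true = pi * (f - kl)

# Pitfall 1: differentiating the sampled estimate mean(log pi(x) - log q(x))
# as a loss. Its per-sample gradient is just the score d log pi(x)/dtheta,
# whose expectation under pi is exactly zero (computed in closed form here).
grad_naive = sum(pi[x] * (np.eye(K)[x] - pi) for x in range(K))

# Correct score-function form: E_x[ f(x) * d log pi(x)/dtheta ],
# again evaluated exactly by summing over all x.
grad_correct = sum(pi[x] * f[x] * (np.eye(K)[x] - pi) for x in range(K))

print(np.allclose(grad_naive, 0.0))         # the "loss" gradient vanishes
print(np.allclose(grad_correct, grad_true)) # score-function form is exact
```

In an autograd framework the same fix amounts to multiplying the detached log-ratio by `log pi(x)` before differentiating, rather than backpropagating through the estimate itself.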
Problem

Research questions and friction points this paper is trying to address.

Biased gradients from differentiating sampled KL estimates as losses
Partial gradients from ignoring the sequential structure of KL estimation
Lack of a correct reference implementation of the KL gradient
Innovation

Methods, ideas, or system contributions that make the work stand out.

Avoid differentiating through the sampled KL estimate as a loss
Account for the sequential token structure in the gradient
Provide a correct, validated KL gradient implementation
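The sequential pitfall can likewise be illustrated on a hypothetical two-token tabular model (again an invented sketch, not the paper's code): if the score term is applied only at the final token, the first-step parameters receive zero gradient, whereas the per-token form matches a finite-difference check of the exact sequence-level KL.

```python
import numpy as np

# Hypothetical two-token tabular model (not from the paper):
# pi(y1) = softmax(th1), pi(y2 | y1) = softmax(th2[y1]), fixed reference q.
def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

V = 3
rng = np.random.default_rng(1)
th1 = rng.normal(size=V)
th2 = rng.normal(size=(V, V))
q1 = softmax(rng.normal(size=V))
q2 = softmax(rng.normal(size=(V, V)))

def seq_kl(t1):
    """Exact sequence-level KL(pi || q), viewed as a function of th1."""
    p1, p2 = softmax(t1), softmax(th2)
    return sum(
        p1[a] * p2[a, b]
        * ((np.log(p1[a]) - np.log(q1[a]))
           + (np.log(p2[a, b]) - np.log(q2[a, b])))
        for a in range(V) for b in range(V)
    )

# Ground truth: central finite differences of the exact KL w.r.t. th1.
eps = 1e-6
grad_fd = np.array([
    (seq_kl(th1 + eps * np.eye(V)[i]) - seq_kl(th1 - eps * np.eye(V)[i]))
    / (2 * eps)
    for i in range(V)
])

p1, p2 = softmax(th1), softmax(th2)
grad_full = np.zeros(V)   # score term applied at every token (correct)
grad_last = np.zeros(V)   # score term at the final token only (pitfall 2)
for a in range(V):
    score1 = np.eye(V)[a] - p1   # d log pi(y1 = a) / d th1
    for b in range(V):
        w = p1[a] * p2[a, b]
        F = (np.log(p1[a]) - np.log(q1[a])) \
            + (np.log(p2[a, b]) - np.log(q2[a, b]))
        grad_full += w * F * score1
        # grad_last receives nothing here: d log pi(y2 | y1) / d th1 == 0,
        # so the first-step parameters get zero gradient.

print(np.allclose(grad_full, grad_fd, atol=1e-5))  # per-token form is exact
print(np.allclose(grad_last, 0.0))                 # final-token-only misses th1
```

The same decomposition carries over to LLM fine-tuning: each token's score term must carry the log-ratio of the remaining sequence, exactly as a policy-gradient return-to-go would.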