🤖 AI Summary
This work addresses the limitation of conventional policy gradient methods in language model reasoning, which treat all generated tokens uniformly and fail to distinguish critical reasoning steps from auxiliary text. The authors propose a causal credit assignment approach that requires neither auxiliary models nor external annotations. By masking specific reasoning segments and measuring their impact on the probability of the final answer, the method dynamically weights token importance during policy gradient updates. Experiments on GSM8K with Qwen and Llama series models demonstrate that this causal weighting significantly improves performance and accelerates convergence compared to uniform credit assignment. Furthermore, reversing the weighting scheme degrades performance, confirming both the effectiveness and necessity of the proposed causal credit assignment mechanism.
📝 Abstract
Policy gradient methods for language model reasoning, such as GRPO and DAPO, assign uniform credit to all generated tokens: the filler phrase "Let me think" receives the same gradient update as the critical calculation "23 + 45 = 68." We propose counterfactual importance weighting: mask reasoning spans, measure the drop in answer probability, and upweight tokens accordingly during policy gradient updates. Our method requires no auxiliary models or external annotation; importance is estimated directly from the policy model's own probability shifts. Experiments on GSM8K across three models spanning the Qwen and Llama families demonstrate consistent improvements over uniform baselines and faster convergence to equivalent accuracy. Inverting the importance signal hurts performance, confirming that we capture genuine causal structure rather than noise. Analysis shows the method correctly prioritizes calculation steps over scaffolding text. We view these findings as establishing counterfactual importance weighting as a foundation for further research rather than a complete solution.
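The mask-measure-upweight loop described in the abstract can be sketched in a few lines. This is a toy illustration under stated assumptions, not the paper's implementation: `answer_logprob` stands in for a real policy model's log-probability of the final answer, and the token/span boundaries are hypothetical placeholders for a model's actual reasoning trace.

```python
# Toy sketch of counterfactual importance weighting.
# Assumption: `answer_logprob` is a stand-in for a real policy model's
# log-probability of the final answer given the reasoning tokens.

def answer_logprob(tokens):
    # Toy model: the "answer" depends only on calculation tokens.
    return sum(1.0 for t in tokens if t.startswith("calc:"))

def span_importance(tokens, spans):
    """Importance of each span = drop in answer log-prob when that span is masked."""
    base = answer_logprob(tokens)
    importances = []
    for start, end in spans:
        masked = tokens[:start] + tokens[end:]  # counterfactual: remove the span
        importances.append(max(base - answer_logprob(masked), 0.0))
    return importances

def token_weights(tokens, spans, floor=0.1):
    """Broadcast span importance to per-token weights, normalized to mean 1,
    so the overall policy-gradient scale is preserved."""
    imps = span_importance(tokens, spans)
    weights = [floor] * len(tokens)
    for (start, end), imp in zip(spans, imps):
        for i in range(start, end):
            weights[i] = floor + imp
    mean = sum(weights) / len(weights)
    return [w / mean for w in weights]

# Hypothetical trace: filler scaffolding vs. calculation steps.
tokens = ["Let", "me", "think", "calc:23+45=68", "so", "calc:68"]
spans = [(0, 3), (3, 4), (4, 6)]
print(token_weights(tokens, spans))
```

In this toy example, masking the filler span "Let me think" leaves the answer log-prob unchanged, so those tokens get the floor weight, while the calculation spans receive upweighted credit; in the actual method these weights would rescale per-token advantages in the GRPO/DAPO update.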