🤖 AI Summary
This work addresses a common issue in large language models: the reasoning trace produced by chain-of-thought prompting can become decoupled from the final answer, which undermines both interpretability and faithfulness. To jointly optimize correctness and explainability, the authors propose AtManRL, a differentiable attention manipulation method trained within the GRPO reinforcement learning framework. Specifically, they train an additive attention mask that identifies the chain-of-thought tokens crucial for producing correct answers, derive a saliency reward from it, and combine that reward with outcome-based rewards. Experiments with Llama-3.2-3B-Instruct on GSM8K and MMLU indicate that the approach can identify influential reasoning tokens and enables training more transparent reasoning models.
📝 Abstract
Large language models (LLMs) increasingly rely on chain-of-thought (CoT) reasoning to solve complex tasks. Yet ensuring that the reasoning trace both contributes to and faithfully reflects the processes underlying the model's final answer, rather than merely accompanying it, remains challenging. We introduce AtManRL, a method that leverages differentiable attention manipulation to learn more faithful reasoning through reinforcement learning. By training an additive attention mask that identifies tokens in the CoT crucial for producing correct answers, we derive a saliency reward signal that encourages the model to generate reasoning traces that genuinely influence its final predictions. We integrate this saliency reward with outcome-based rewards within the GRPO framework to jointly optimize for correctness and interpretability. Experiments on GSM8K and MMLU with Llama-3.2-3B-Instruct demonstrate that our approach can identify influential reasoning tokens and enable training more transparent reasoning models.
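As a rough illustration of how a saliency reward might be combined with an outcome reward under GRPO, the sketch below mixes the two signals and then standardizes the rewards within one group of sampled completions, which is the group-relative advantage GRPO is generally described as using. The abstract does not specify how the two rewards are combined, so the weighted sum, the `weight` parameter, the function names, and all numeric values are assumptions made for illustration and are not the authors' implementation.

```python
from statistics import mean, pstdev


def combined_reward(outcome, saliency, weight=0.5):
    """Hypothetical mix: outcome reward plus a weighted saliency reward.

    The weighted sum and the default weight are assumptions; the abstract
    only states that the two rewards are integrated.
    """
    return outcome + weight * saliency


def group_advantages(rewards, eps=1e-6):
    """Group-relative advantages: standardize rewards within one prompt's group.

    Uses the population standard deviation; implementations may differ on
    this detail.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled completions for a single prompt (made-up numbers).
outcome = [1.0, 1.0, 0.0, 0.0]   # 1.0 if the final answer is correct
saliency = [0.8, 0.2, 0.6, 0.1]  # assumed score for how much the CoT influences the answer

rewards = [combined_reward(o, s) for o, s in zip(outcome, saliency)]
advantages = group_advantages(rewards)

for r, a in zip(rewards, advantages):
    print(f"reward={r:.2f}  advantage={a:+.3f}")
```

Under these assumed numbers, two completions with the same correct answer receive different advantages: the one whose reasoning tokens carry more saliency is ranked higher, which is the kind of pressure toward influential reasoning traces that the abstract describes.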