🤖 AI Summary
To address the high computational cost of reinforced fine-tuning (ReFT) for large language models (LLMs) on complex tasks such as mathematical reasoning, much of which comes from generating multiple completions per problem during training, this paper proposes Nested-ReFT, an efficient ReFT framework built on off-policy learning. Methodologically, it (1) reuses a subset of the target model's layers as the behavior policy, with dynamic layer skipping per batch to substantially reduce inference overhead; and (2) draws on off-policy reinforcement learning and speculative decoding, backed by a theoretical analysis showing that the resulting gradient estimates are unbiased with controlled variance, plus three bias-mitigation variants that reduce off-policyness in the gradient updates. Experiments across multiple mathematical reasoning benchmarks and model sizes show improved computational efficiency (tokens/sec) while matching the performance of standard ReFT.
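To make the layer-skipping idea concrete, here is a minimal Python sketch. It assumes toy residual layers in place of transformer blocks and a random choice of which middle layers to skip for each batch. The summary does not say how the skipped layers are selected, so `make_residual_layer`, `sample_active_layers`, and the keep-first-and-last rule are hypothetical. The behavior model reuses the target's own layer objects, so it needs no extra weights, and its per-token cost scales roughly with the fraction of layers it runs.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def make_residual_layer(scale: float) -> Layer:
    """Hypothetical toy stand-in for a transformer block: h -> h + tanh(scale * h)."""
    def layer(h: Vector) -> Vector:
        return [x + math.tanh(scale * x) for x in h]
    return layer


def sample_active_layers(num_layers: int, num_skip: int, rng: random.Random) -> List[int]:
    """Pick the layers the behavior model runs for one batch.

    Assumption for this sketch: the first and last layers are always kept,
    and `num_skip` middle layers are dropped uniformly at random.
    """
    if not 0 <= num_skip <= num_layers - 2:
        raise ValueError("num_skip must leave the first and last layers in place")
    middle = list(range(1, num_layers - 1))
    skipped = set(rng.sample(middle, num_skip))
    return [i for i in range(num_layers) if i not in skipped]


def forward(layers: Sequence[Layer], h: Vector, active: Sequence[int]) -> Vector:
    """Run only the layers listed in `active`, in order; weights are shared with the target."""
    for i in active:
        h = layers[i](h)
    return h


rng = random.Random(0)
layers = [make_residual_layer(0.1 * (i + 1)) for i in range(12)]
x = [0.5, -0.2, 1.0]

target_out = forward(layers, x, range(len(layers)))  # full target model
for batch in range(3):
    # A new skip pattern is drawn for every batch ("dynamic" layer skipping).
    active = sample_active_layers(len(layers), num_skip=4, rng=rng)
    behavior_out = forward(layers, x, active)  # cheaper nested behavior model
    print(batch, active, f"cost ratio = {len(active) / len(layers):.2f}")
```

Because the behavior model runs only a subset of the target's layers, its completions come from a different distribution than the full model's, which is what makes them off-policy with respect to the target.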
📝 Abstract
Challenging reasoning tasks for LLMs, such as mathematical reasoning, can be tackled with reinforced fine-tuning (ReFT) based on verifiable rewards. In standard ReFT frameworks, a behavior model generates multiple completions with answers per problem, and each answer is then scored by a reward function. While such RL post-training methods yield significant performance improvements across challenging reasoning domains, the cost of generating these completions during training, each requiring many inference steps, makes training expensive. To address this, we draw inspiration from off-policy RL and speculative decoding to introduce a novel ReFT framework, dubbed Nested-ReFT, in which a subset of layers of the target model acts as the behavior model that generates off-policy completions during training. Configured with dynamic layer skipping per batch, the behavior model reduces inference cost compared to standard ReFT frameworks. Our theoretical analysis shows that Nested-ReFT yields unbiased gradient estimates with controlled variance. Our empirical analysis demonstrates improved computational efficiency, measured in tokens/sec, across multiple math reasoning benchmarks and model sizes. Additionally, we explore three bias-mitigation variants that minimize off-policyness in the gradient updates, allowing Nested-ReFT to maintain performance that matches the baseline ReFT.
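The unbiasedness claim can be illustrated with the standard importance-sampling identity: if completions are drawn from a behavior policy mu and each is reweighted by pi(a)/mu(a), the expected score-function gradient equals the on-policy gradient, provided mu(a) > 0 wherever pi(a) > 0. The sketch below checks this numerically for a single-step softmax policy over three hypothetical answers, where the exact gradient is d/dz_j E_pi[r] = pi_j * (r_j - E_pi[r]). The fixed `behavior` distribution stands in for a layer-skipped model, and the rewards and function names are illustrative assumptions; this is a generic off-policy estimator, not the paper's exact estimator or any of its bias-mitigation variants.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_gradient(logits: Sequence[float], rewards: Sequence[float]) -> List[float]:
    """On-policy gradient of E_{a~pi}[r(a)] w.r.t. the logits: pi_j * (r_j - E_pi[r])."""
    pi = softmax(logits)
    baseline = sum(p * r for p, r in zip(pi, rewards))
    return [p * (r - baseline) for p, r in zip(pi, rewards)]


def off_policy_gradient(logits, behavior_probs, rewards, num_samples, seed=0):
    """Importance-weighted score-function estimate from samples of the behavior policy mu."""
    pi = softmax(logits)
    rng = random.Random(seed)
    actions = rng.choices(range(len(pi)), weights=behavior_probs, k=num_samples)
    grad = [0.0] * len(pi)
    for a in actions:
        w = pi[a] / behavior_probs[a]  # importance ratio pi(a) / mu(a)
        for j in range(len(pi)):
            score = (1.0 if a == j else 0.0) - pi[j]  # d log pi(a) / d z_j
            grad[j] += w * rewards[a] * score
    return [g / num_samples for g in grad]


target_logits = [1.0, 0.0, -1.0]  # target policy pi_theta (hypothetical)
behavior = [0.3, 0.4, 0.3]        # behavior policy mu, e.g. a layer-skipped model (hypothetical)
rewards = [1.0, 0.0, 0.5]         # verifiable reward per answer (hypothetical)

exact = exact_gradient(target_logits, rewards)
estimate = off_policy_gradient(target_logits, behavior, rewards, num_samples=100_000)
print("exact:   ", [round(g, 4) for g in exact])
print("estimate:", [round(g, 4) for g in estimate])
```

For a full completion, the ratio becomes a product of per-token ratios whose variance can grow with sequence length, which is one reason variance control matters in practice; clipping or truncating the weights is a common way to reduce variance at the cost of some bias.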