Inference-Aware Fine-Tuning for Best-of-N Sampling in Large Language Models

📅 2024-12-18
🏛️ International Conference on Learning Representations
📈 Citations: 40
Influential: 2
🤖 AI Summary
The argmax operation in Best-of-N (BoN) decoding is non-differentiable, which blocks end-to-end optimization of large language models (LLMs). To address this, the work proposes an inference-aware fine-tuning paradigm: the first BoN-aware training framework combining imitation learning and reinforcement learning, using gradient approximation and policy distillation to circumvent the non-differentiable bottleneck. The fine-tuned model is further shown to implicitly acquire a meta-strategy that balances exploration and exploitation during BoN inference. Experiments on Gemma 2B demonstrate substantial gains: on Hendrycks MATH, Bo32 accuracy improves by 4.0 percentage points (26.8% → 30.8%) and pass@32 by 7.0 points (60.0% → 67.0%); on HumanEval, pass@16 rises by 5.5 points (61.6% → 67.1%). These results indicate improvements in both reasoning quality and inference-time compute efficiency under BoN sampling.

📝 Abstract
Recent studies have indicated that effectively utilizing inference-time compute is crucial for attaining better performance from large language models (LLMs). In this work, we propose a novel inference-aware fine-tuning paradigm, in which the model is fine-tuned in a manner that directly optimizes the performance of the inference-time strategy. We study this paradigm using the simple yet effective Best-of-N (BoN) inference strategy, in which a verifier selects the best out of a set of LLM-generated responses. We devise the first imitation learning and reinforcement learning (RL) methods for BoN-aware fine-tuning, overcoming the challenging, non-differentiable argmax operator within BoN. We empirically demonstrate that our BoN-aware models implicitly learn a meta-strategy that interleaves best responses with more diverse responses that might be better suited to a test-time input -- a process reminiscent of the exploration-exploitation trade-off in RL. Our experiments demonstrate the effectiveness of BoN-aware fine-tuning in terms of improved performance and inference-time compute. In particular, we show that our methods improve the Bo32 performance of Gemma 2B on Hendrycks MATH from 26.8% to 30.8%, and pass@32 from 60.0% to 67.0%, as well as the pass@16 on HumanEval from 61.6% to 67.1%.
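For readers unfamiliar with the inference strategy studied here, BoN decoding can be sketched in a few lines. This is an illustrative toy, not the paper's code: `generate` and `verifier_score` are hypothetical stand-ins for an LLM sampler and a learned verifier (reward model).

```python
import random

def best_of_n(prompt, generate, verifier_score, n=32, seed=0):
    """Best-of-N decoding: sample n candidate responses and return
    the one the verifier scores highest. The final selection is the
    hard argmax that makes BoN non-differentiable end-to-end."""
    rng = random.Random(seed)
    candidates = [generate(prompt, rng) for _ in range(n)]
    return max(candidates, key=verifier_score)

# Toy demo with stub components (not a real LLM or verifier).
def generate(prompt, rng):
    return f"{prompt}-answer-{rng.randint(0, 9)}"

def verifier_score(response):
    # Stub verifier: prefer responses ending in a higher digit.
    return int(response.rsplit("-", 1)[1])

best = best_of_n("2+2", generate, verifier_score, n=8)
```

Increasing `n` raises the chance that some candidate is correct (pass@N), but plain BoN only helps if the verifier can pick that candidate out, which is what motivates fine-tuning the model with the BoN selection step in mind.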
Problem

Research questions and friction points this paper is trying to address.

Optimizing LLM fine-tuning for inference-time compute efficiency
Developing imitation and RL methods for Best-of-N sampling
Improving model performance through exploration-exploitation trade-offs
Innovation

Methods, ideas, or system contributions that make the work stand out.

Fine-tuning optimizes inference-time strategy performance directly
Uses imitation and reinforcement learning for Best-of-N sampling
Learns meta-strategy balancing best and diverse responses
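To make the argmax obstacle concrete: one generic workaround (illustrative only; the paper's actual contributions are BoN-aware imitation learning and RL with gradient approximation and policy distillation) is to relax the hard selection into a temperature-controlled softmax over verifier scores, which is differentiable and approaches the one-hot argmax as the temperature goes to zero.

```python
import math

def soft_bon_weights(scores, temperature=1.0):
    """Softmax over verifier scores: a differentiable surrogate for
    BoN's hard argmax over candidates. As temperature -> 0 the
    weights approach a one-hot vector on the best-scoring candidate."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)                           # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    return [e / z for e in exps]

# At a moderate temperature, mass is spread but concentrated on the best.
weights = soft_bon_weights([1.0, 2.0, 4.0], temperature=0.5)
# At a very low temperature, the weights are nearly one-hot.
sharp = soft_bon_weights([1.0, 2.0, 4.0], temperature=0.01)
```

Such relaxations trade bias for differentiability; the temperature controls how closely the soft selection tracks the true BoN outcome.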