LASeR: Learning to Adaptively Select Reward Models with Multi-Armed Bandits

📅 2024-10-02
🏛️ arXiv.org
📈 Citations: 3
Influential: 0
🤖 AI Summary
To address the poor generalization of individual reward models (RMs), conflicting reward signals, and the high computational cost of jointly optimizing against multiple RMs, this paper proposes a dynamic RM selection mechanism based on the multi-armed bandit (MAB) framework. It formulates RM selection as an online learning problem, enabling task- and instance-level adaptive generation of preference data. The method combines Upper Confidence Bound (UCB) exploration with best-of-n sampling and iterative preference-based training. On Llama-3-8B, it improves absolute average accuracy on three commonsense and math reasoning datasets by 2.67% over training with ensembled RM scores while roughly doubling training speed; achieves a 71.45% AlpacaEval win rate on WildChat against sequentially optimizing multiple RMs; and, with best-of-n sampling, improves F1 by 2.64 and 2.42 points on single- and multi-document long-context QA over random RM selection. This work is the first to formalize RM selection as an MAB problem, balancing alignment performance, robustness to reward conflicts, and training efficiency.
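
The summary frames RM scheduling as an online learning problem solved with UCB exploration. As a minimal sketch of that idea, assuming a plain (non-contextual) UCB1 bandit in which each arm is one RM and the bandit receives a scalar reward after each round, the selector below balances trying under-used RMs against reusing RMs that have paid off. How LASeR actually defines the bandit's reward is not stated here, so the `simulate` helper and its Bernoulli success probabilities are hypothetical stand-ins, not results from the paper.

```python
import math
import random
from typing import List


class UCB1Selector:
    """UCB1 bandit in which each arm is one reward model (RM)."""

    def __init__(self, num_arms: int, c: float = 2.0) -> None:
        self.c = c  # exploration weight; c=2 gives the textbook UCB1 bonus
        self.counts = [0] * num_arms  # times each RM has been selected
        self.values = [0.0] * num_arms  # running mean reward per RM
        self.t = 0  # total number of selections so far

    def select(self) -> int:
        self.t += 1
        # Try every RM once before trusting the confidence bounds.
        for arm, n in enumerate(self.counts):
            if n == 0:
                return arm
        scores = [
            mean + math.sqrt(self.c * math.log(self.t) / n)
            for mean, n in zip(self.values, self.counts)
        ]
        return max(range(len(scores)), key=scores.__getitem__)

    def update(self, arm: int, reward: float) -> None:
        self.counts[arm] += 1
        # Incremental mean update avoids storing the reward history.
        self.values[arm] += (reward - self.values[arm]) / self.counts[arm]


def simulate(probs: List[float], rounds: int, seed: int = 0) -> List[int]:
    """Hypothetical environment: RM i 'pays off' with probability probs[i]."""
    rng = random.Random(seed)
    bandit = UCB1Selector(len(probs))
    for _ in range(rounds):
        arm = bandit.select()
        reward = 1.0 if rng.random() < probs[arm] else 0.0
        bandit.update(arm, reward)
    return bandit.counts


print(simulate([0.3, 0.5, 0.8], 2000))
```

Here `simulate([0.3, 0.5, 0.8], 2000)` should concentrate most selections on the RM at index 2 while still sampling the others occasionally. The exploration weight `c=2.0` reproduces the standard UCB1 bonus sqrt(2 ln t / n); it is an illustrative default, not a value taken from the paper.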

📝 Abstract
Reward Models (RMs) play a crucial role in aligning LLMs with human preferences, enhancing their performance by ranking outputs during inference or iterative training. However, the degree to which an RM generalizes to new tasks is often not known a priori (e.g., some RMs may excel at scoring creative writing but not math reasoning). Therefore, using only one fixed RM while training LLMs can be suboptimal. Moreover, optimizing LLMs with multiple RMs simultaneously can be prohibitively computationally intensive and challenging due to conflicting signals from different RMs, potentially degrading performance. To address these challenges, we introduce LASeR (Learning to Adaptively Select Rewards), which iteratively trains LLMs using multiple RMs, selecting the best-suited RM for each instance to rank outputs and generate preference data, framed as a multi-armed bandit problem. Our results on commonsense and math reasoning tasks demonstrate that LASeR can boost iterative LLM optimization by optimizing for multiple RMs, improving the absolute average accuracy of Llama-3-8B over three datasets by 2.67% over training with ensemble RM scores while also showing superior training efficiency (e.g., a 2x speedup). Moreover, on WildChat, a benchmark of instruction-following prompts, we find that LASeR with Llama-3-8B achieves a 71.45% AlpacaEval win rate against sequentially optimizing multiple RMs. Extending to long-context generation tasks, we find that on Llama-3-8B, LASeR achieves average improvements of 2.64 F1 and 2.42 F1 on single- and multi-document QA, respectively, over random RM selection when used with best-of-n sampling. LASeR is robust to noisy rewards and generalizes to multiple settings. Finally, LASeR's RM selection changes depending on the underlying task or instance, and we verify the presence of conflicting preferences from multiple RMs that can be mitigated using LASeR.
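
The abstract says the selected RM ranks sampled outputs to generate preference data. A minimal sketch, assuming n candidate responses per prompt and a best-versus-worst pairing rule (the pairing rule is an assumption, not confirmed by the source), is shown below. The two toy RMs, `length_rm` and `digit_rm`, are hypothetical scorers chosen only to show how two RMs can produce opposite preference pairs for the same prompt, the kind of conflicting preferences the abstract mentions.

```python
from typing import Callable, List, Tuple

# Hypothetical RM interface: (prompt, response) -> scalar score.
RewardModel = Callable[[str, str], float]


def build_preference_pair(
    prompt: str, candidates: List[str], rm: RewardModel
) -> Tuple[str, str]:
    """Score n sampled responses with one RM; return (chosen, rejected)."""
    if len(candidates) < 2:
        raise ValueError("need at least two candidates to form a pair")
    scores = [rm(prompt, response) for response in candidates]
    best = max(range(len(candidates)), key=scores.__getitem__)
    worst = min(range(len(candidates)), key=scores.__getitem__)
    return candidates[best], candidates[worst]


# Two hypothetical toy RMs with different "tastes".
def length_rm(prompt: str, response: str) -> float:
    return float(len(response))  # prefers longer answers


def digit_rm(prompt: str, response: str) -> float:
    return float(sum(ch.isdigit() for ch in response))  # prefers explicit arithmetic


prompt = "What is 4 + 38?"
candidates = [
    "The answer is 42.",
    "Let me think step by step about this question carefully.",
    "4 + 38 = 42",
]
print(build_preference_pair(prompt, candidates, length_rm))
print(build_preference_pair(prompt, candidates, digit_rm))
```

The two calls return the same two responses in opposite roles: `length_rm` prefers the long, digit-free answer, while `digit_rm` prefers `"4 + 38 = 42"`. Training on both pairs at once would push the policy in contradictory directions, which illustrates why the abstract argues for selecting one RM per instance rather than mixing them.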

Problem

Research questions and friction points this paper is trying to address.

Optimizing LLMs with multiple RMs is computationally expensive
Fixed RMs may not generalize well to new tasks
Conflicting signals from different RMs degrade performance

Innovation

Methods, ideas, or system contributions that make the work stand out.

Adaptive RM selection via multi-armed bandits
Efficient training with instance-specific RM choice
Improves LLM accuracy and training efficiency
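
The list above stresses instance-specific RM choice. A bandit that ignores the input cannot specialize per instance, so one way to let the choice depend on each prompt is a contextual bandit. The sketch below uses disjoint LinUCB, a standard contextual-bandit algorithm swapped in for illustration; the source does not state which contextual model, features, or bandit reward LASeR uses. Context vectors here are hypothetical one-hot task features (`math`, `creative`); a real system might use prompt embeddings instead, which is also an assumption. The success probabilities in `PROBS` are made-up simulation parameters, not results.

```python
import math
import random
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


class LinUCBArm:
    """Per-RM ridge-regression model used by disjoint LinUCB."""

    def __init__(self, dim: int, reg: float = 1.0) -> None:
        # Store A^{-1} directly, starting from (reg * I)^{-1}.
        self.a_inv: Matrix = [
            [1.0 / reg if i == j else 0.0 for j in range(dim)] for i in range(dim)
        ]
        self.b: Vector = [0.0] * dim

    def ucb(self, x: Vector, alpha: float) -> float:
        theta = matvec(self.a_inv, self.b)  # ridge estimate A^{-1} b
        a_inv_x = matvec(self.a_inv, x)
        return dot(theta, x) + alpha * math.sqrt(dot(x, a_inv_x))

    def update(self, x: Vector, reward: float) -> None:
        # Sherman-Morrison: (A + x x^T)^{-1} = A^{-1} - (A^{-1}x)(A^{-1}x)^T / (1 + x^T A^{-1} x)
        a_inv_x = matvec(self.a_inv, x)
        denom = 1.0 + dot(x, a_inv_x)
        for i, ui in enumerate(a_inv_x):
            for j, uj in enumerate(a_inv_x):
                self.a_inv[i][j] -= ui * uj / denom
        for i, xi in enumerate(x):
            self.b[i] += reward * xi


class LinUCBSelector:
    """Picks one RM per instance from the instance's context vector."""

    def __init__(self, num_arms: int, dim: int, alpha: float = 1.0) -> None:
        self.alpha = alpha
        self.arms = [LinUCBArm(dim) for _ in range(num_arms)]

    def select(self, x: Vector) -> int:
        scores = [arm.ucb(x, self.alpha) for arm in self.arms]
        return max(range(len(scores)), key=scores.__getitem__)

    def update(self, arm: int, x: Vector, reward: float) -> None:
        self.arms[arm].update(x, reward)


# Hypothetical one-hot task features and made-up per-task RM success rates.
CONTEXTS: Dict[str, Vector] = {"math": [1.0, 0.0], "creative": [0.0, 1.0]}
PROBS: List[Dict[str, float]] = [
    {"math": 0.9, "creative": 0.2},  # RM 0: strong on math
    {"math": 0.2, "creative": 0.9},  # RM 1: strong on creative writing
]


def simulate(
    rounds: int = 1000, seed: int = 0
) -> Tuple[LinUCBSelector, Dict[Tuple[str, int], int]]:
    rng = random.Random(seed)
    selector = LinUCBSelector(num_arms=len(PROBS), dim=2, alpha=1.0)
    pulls = {(task, arm): 0 for task in CONTEXTS for arm in range(len(PROBS))}
    for _ in range(rounds):
        task = rng.choice(sorted(CONTEXTS))
        x = CONTEXTS[task]
        arm = selector.select(x)
        reward = 1.0 if rng.random() < PROBS[arm][task] else 0.0
        selector.update(arm, x, reward)
        pulls[(task, arm)] += 1
    return selector, pulls


selector, pulls = simulate()
print(pulls)
```

After the simulated rounds, each arm's greedy estimate (`ucb(x, 0.0)`) should favor RM 0 on the math context and RM 1 on the creative context. The inverse design matrix is updated with the Sherman-Morrison formula, which avoids a full matrix inversion per round and keeps the sketch dependency-free.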