🤖 AI Summary
To address the challenge of efficiently fine-tuning large language models (LLMs) on multiple datasets spanning different tasks, this paper proposes a lightweight ensemble method built from grouped adapters. Leveraging the first-order approximation property of LoRA, the approach uses base-model gradients to rapidly estimate the fine-tuning performance of dataset combinations. These estimates guide how the datasets are partitioned into groups. One adapter is trained per group, and the adapters are combined through a weighted ensemble instead of training a single adapter per task. The method outperforms QLoRA in multi-task settings at a small extra cost. On ten text classification tasks with Llama and GPT models, it achieves up to 10% higher average test accuracy with only 9% more FLOPs. On a 34B-parameter Llama model, it improves test accuracy by 3% with only 8% more FLOPs. The first-order approximation holds with less than 1% error, performance estimates stay within 5% of true fine-tuning results, and estimation runs 105× faster than base fine-tuning.
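The first-order idea can be checked numerically on a toy model. The sketch below assumes a hypothetical one-layer model $f(x; W) = v^\top \tanh(Wx)$ whose weight receives a small low-rank update $\Delta W = BA$, as a LoRA adapter would add. The estimate $f(W_0) + \langle \nabla_W f(W_0), \Delta W \rangle$ needs only the gradient at the base weights $W_0$. The model, the sizes, and the `1e-3` update scale are illustrative assumptions. This is not the paper's implementation.

```python
import numpy as np

# Hypothetical toy model: a scalar logit f(x; W) = v^T tanh(W x).
def model_output(W, v, x):
    return float(v @ np.tanh(W @ x))

def output_gradient(W, v, x):
    """Analytic gradient of v^T tanh(W x) with respect to W."""
    h = np.tanh(W @ x)
    return np.outer(v * (1.0 - h ** 2), x)

def first_order_estimate(W0, v, x, delta):
    """Taylor estimate: f(W0 + delta) ~= f(W0) + <grad_W f(W0), delta>."""
    grad = output_gradient(W0, v, x)
    return model_output(W0, v, x) + float(np.sum(grad * delta))

rng = np.random.default_rng(0)
d_out, d_in, rank = 16, 32, 2
W0 = rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_out, d_in))  # frozen base weights
v = rng.normal(size=d_out)
x = rng.normal(size=d_in)

# Low-rank update B @ A, scaled down so the adapted weights stay close to W0
# (the 1e-3 scale is an assumption chosen for this illustration).
A = rng.normal(size=(rank, d_in))
B = rng.normal(size=(d_out, rank))
delta = 1e-3 * (B @ A)

base = model_output(W0, v, x)
exact = model_output(W0 + delta, v, x)
approx = first_order_estimate(W0, v, x, delta)
print(f"change from base:  {exact - base:.3e}")
print(f"first-order error: {abs(exact - approx):.3e}")
```

The estimate is linear in $\Delta W$, so the base-model gradients only need to be computed once. The remaining error shrinks quadratically as the update gets smaller, which is why adapters that stay close to the base model suit this approximation.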
📝 Abstract
This paper develops an ensemble method for fine-tuning a language model on multiple datasets. Existing methods, such as quantized LoRA (QLoRA), are efficient when adapting to a single dataset. When training on multiple datasets from different tasks, a common setup in practice, it remains unclear how to design an efficient adaptation scheme for fine-tuning language models. We propose to use an ensemble of multiple smaller adapters instead of a single adapter per task. We design an efficient algorithm that partitions $n$ datasets into $m$ groups, where $m$ is typically much smaller than $n$ in practice, trains one adapter for each group, and then forms the ensemble as a weighted combination of these adapters. To quickly obtain the fine-tuning performance of dataset combinations, the algorithm leverages a first-order approximation property of low-rank adaptation. Because methods like LoRA stay close to the base model, we use the gradients of the base model to estimate its behavior during fine-tuning. Empirically, this approximation holds with less than $1\%$ error on models with up to $34$ billion parameters. This yields estimates of true fine-tuning performance within $5\%$ error while speeding up computation by $105$ times compared to base fine-tuning. When applied to fine-tune Llama and GPT models on ten text classification tasks, our approach provides up to $10\%$ higher average test accuracy than QLoRA, with only $9\%$ more FLOPs. On a Llama model with $34$ billion parameters, an ensemble of QLoRA adapters increases test accuracy by $3\%$ compared to QLoRA, with only $8\%$ more FLOPs.
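The abstract does not specify how the $m$ groups are formed or how the ensemble weights are chosen. The sketch below assumes that an $n \times n$ affinity matrix is already available. Each entry would be an estimated benefit of one dataset to another, obtained for example from the gradient-based approximation. The grouping heuristic (farthest-first seeding, then assignment by mean affinity) is an illustrative choice and not the paper's algorithm. The names `partition_datasets` and `ensemble_predict` are hypothetical. The ensemble step is a plain weighted average of each group adapter's class probabilities, with weights supplied by the caller.

```python
import numpy as np

def partition_datasets(affinity, m):
    """Hypothetical greedy heuristic: split n datasets into m groups.

    affinity[i, j] is assumed to estimate how much training on dataset j
    helps dataset i. This is NOT the paper's grouping algorithm."""
    n = affinity.shape[0]
    if not 1 <= m <= n:
        raise ValueError("need 1 <= m <= n")
    sym = (affinity + affinity.T) / 2.0  # symmetrize the estimates
    # Farthest-first seeding: each new seed is the dataset least related to existing seeds.
    seeds = [0]
    while len(seeds) < m:
        candidates = [i for i in range(n) if i not in seeds]
        seeds.append(min(candidates, key=lambda i: sym[i, seeds].max()))
    groups = [[s] for s in seeds]
    # Assign every other dataset to the group with the highest mean affinity.
    for i in range(n):
        if i in seeds:
            continue
        best = max(range(m), key=lambda g: sym[i, groups[g]].mean())
        groups[best].append(i)
    return [sorted(g) for g in groups]

def ensemble_predict(group_probs, weights):
    """Weighted average of per-group adapter class probabilities.

    group_probs has shape (m, N, C); weights has shape (m,) and is nonnegative."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalize so the result is still a probability distribution
    return np.tensordot(w, group_probs, axes=1)  # shape (N, C)

# Toy example: datasets 0-2 and 3-5 form two clearly related blocks.
affinity = np.full((6, 6), 0.1)
affinity[:3, :3] = 0.9
affinity[3:, 3:] = 0.9
groups = partition_datasets(affinity, m=2)
print(groups)  # [[0, 1, 2], [3, 4, 5]]

# Class probabilities from the two group adapters on two test examples.
group_probs = np.array([
    [[0.8, 0.2], [0.4, 0.6]],  # adapter trained on group 0
    [[0.6, 0.4], [0.2, 0.8]],  # adapter trained on group 1
])
combined = ensemble_predict(group_probs, weights=[3.0, 1.0])
print(combined)  # weights normalize to [0.75, 0.25]
```

In practice the weights might be tuned on held-out data for each target dataset. That choice is an assumption here, since the abstract states only that the combination is weighted.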