AI Summary
Existing gradient manipulation methods in multi-task learning (MTL) incur computational overhead linear in the number of tasks $K$ ($\mathcal{O}(K)$), limiting scalability. To address this, we propose BiLB4MTL, the first scalable loss-balancing framework for MTL grounded in bilevel optimization: the upper level automatically learns task weights to approach Pareto optimality, while the lower level performs model updates. By introducing initial loss normalization and a first-order approximation algorithm, BiLB4MTL reduces both time and memory complexity to $\mathcal{O}(1)$. We provide theoretical guarantees showing convergence to a Pareto-stationary point. Empirically, BiLB4MTL achieves state-of-the-art accuracy across multiple MTL benchmarks while substantially reducing training cost. This work delivers an efficient, scalable solution for large-scale MTL, overcoming key bottlenecks in existing gradient-based task-balancing approaches.
Abstract
Multi-task learning (MTL) has been widely adopted for its ability to simultaneously learn multiple tasks. While existing gradient manipulation methods often yield more balanced solutions than simple scalarization-based approaches, they typically incur a significant computational overhead of $\mathcal{O}(K)$ in both time and memory, where $K$ is the number of tasks. In this paper, we propose BiLB4MTL, a simple and scalable loss balancing approach for MTL, formulated from a novel bilevel optimization perspective. Our method incorporates three key components: (i) an initial loss normalization, (ii) a bilevel loss-balancing formulation, and (iii) a scalable first-order algorithm that requires only $\mathcal{O}(1)$ time and memory. Theoretically, we prove that BiLB4MTL guarantees convergence not only to a stationary point of the bilevel loss balancing problem but also to an $\epsilon$-accurate Pareto stationary point for all $K$ loss functions under mild conditions. Extensive experiments on diverse multi-task datasets demonstrate that BiLB4MTL achieves state-of-the-art performance in both accuracy and efficiency. Code is available at https://github.com/OptMN-Lab/-BiLB4MTL.
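To make the bilevel structure concrete, here is a minimal toy sketch of the idea described above: the lower level takes a model step on the weighted sum of initial-loss-normalized task losses, while the upper level adjusts task-weight parameters with a cheap first-order update. All names, the two quadratic toy losses, and the specific weight-update rule are illustrative assumptions for exposition, not the paper's exact BiLB4MTL algorithm.

```python
import numpy as np

# Toy problem with K = 2 tasks: L_k(x) = 0.5 * (x - c_k)^2 for a scalar
# "model" x. The real method applies to neural networks; this only
# illustrates the bilevel loop structure (hypothetical sketch).
targets = np.array([1.0, -1.0])

def losses(x):
    return 0.5 * (x - targets) ** 2

def grads(x):
    return x - targets  # dL_k/dx for each task k

x = 3.0
logits = np.zeros(2)                  # upper-level task-weight parameters
init = np.maximum(losses(x), 1e-8)    # initial-loss normalization constants
lr_x, lr_w = 0.1, 0.1

for _ in range(500):
    w = np.exp(logits) / np.exp(logits).sum()  # weights on the simplex
    norm_l = losses(x) / init                  # normalized task losses
    norm_g = grads(x) / init                   # their gradients w.r.t. x
    # Lower level: one gradient step on the weighted normalized loss.
    # Cost is a single weighted backward pass, i.e. O(1) in K.
    x -= lr_x * float(w @ norm_g)
    # Upper level (first-order approximation): shift weight toward the
    # task whose normalized loss currently lags behind the others.
    logits += lr_w * (norm_l - norm_l.mean())

# At convergence the normalized losses are (approximately) balanced.
print(losses(x) / init)
```

Note that both levels use only first-order information and one pass over the losses per step, which is the source of the $\mathcal{O}(1)$ time and memory claim; in contrast, gradient manipulation methods need one backward pass per task, hence $\mathcal{O}(K)$.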