🤖 AI Summary
In multitask learning, static or heuristic data mixing strategies are inefficient when gradient conflict between tasks is weak, as in multilingual or multi-domain large language model (LLM) pretraining. To address this, we propose PiKE, an adaptive task sampling method grounded in positive gradient interaction. Its core idea is to use gradient inner products to estimate *positive synergistic effects* among tasks, enabling online optimization of task weights. We provide theoretical convergence guarantees and formal fairness assurances across tasks. PiKE adds almost no computational overhead and requires no architectural modifications. Extensive experiments on large-scale LLM pretraining show that PiKE accelerates convergence and consistently outperforms mainstream baselines, achieving higher average downstream task performance.
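As a rough illustration of the core idea, the sketch below turns pairwise gradient inner products into task sampling weights: tasks whose gradients align positively with the other tasks' gradients receive more weight. The function name, the clipping of negative interactions, and the softmax temperature are illustrative assumptions, not the paper's exact estimator.

```python
import numpy as np

def estimate_mixing_weights(task_grads, temperature=1.0):
    """Estimate task sampling weights from positive gradient interactions.

    task_grads: array of shape (K, d), one flattened gradient per task
    (e.g., averaged over a small probe batch). Returns a length-K probability
    vector; tasks whose gradients align with the others get more weight.
    """
    # Pairwise inner products <g_i, g_j>; the diagonal holds each task's
    # own squared gradient norm.
    interactions = task_grads @ task_grads.T          # shape (K, K)

    # A task's "synergy score": how much its gradient also helps the other
    # tasks (sum of cross inner products, clipped at zero so occasional
    # conflicts do not yield negative scores).
    cross = interactions - np.diag(np.diag(interactions))
    synergy = np.clip(cross, 0, None).sum(axis=1)

    # Convert scores into a sampling distribution with a softmax.
    logits = synergy / max(temperature, 1e-8)
    logits -= logits.max()                            # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    grads = rng.normal(size=(4, 128))                 # 4 tasks, toy 128-dim gradients
    print(estimate_mixing_weights(grads))
```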
📝 Abstract
Modern machine learning models are trained on diverse datasets and tasks to improve generalization. A key challenge in multitask learning is determining the optimal data mixing and sampling strategy across different data sources. Prior research in this multitask learning setting has primarily focused on mitigating gradient conflicts between tasks. However, we observe that many real-world multitask learning scenarios, such as multilingual training and multi-domain learning in large foundation models, exhibit predominantly positive task interactions with minimal or no gradient conflict. Building on this insight, we introduce PiKE (Positive gradient interaction-based K-task weights Estimator), an adaptive data mixing algorithm that dynamically adjusts task contributions throughout training. PiKE optimizes task sampling to minimize overall loss, effectively leveraging positive gradient interactions with almost no additional computational overhead. We establish theoretical convergence guarantees for PiKE and demonstrate its superiority over static and non-adaptive mixing strategies. Additionally, we extend PiKE to promote fair learning across tasks, ensuring balanced progress and preventing task underrepresentation. Empirical evaluations on large-scale language model pretraining show that PiKE consistently outperforms existing heuristic and static mixing strategies, leading to faster convergence and improved downstream task performance.
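The fairness extension can be pictured with a small heuristic like the one below, which blends the synergy-based weights with a term that upweights tasks showing the least loss progress. The `fair_adjust` function, its `fairness` knob, and the progress measure are assumptions for illustration and may differ from the paper's actual fairness formulation.

```python
import numpy as np

def fair_adjust(weights, task_losses, init_losses, fairness=0.5):
    """Blend synergy-based weights with a fairness term.

    The fairness term upweights tasks whose loss has decreased the least
    relative to its initial value, so no task stagnates. `fairness` in [0, 1]
    interpolates between pure synergy weights (0) and pure fairness weights (1).
    Illustrative heuristic only; not necessarily the paper's formulation.
    """
    progress = 1.0 - np.asarray(task_losses) / np.asarray(init_losses)  # higher = more learned
    lag = np.clip(1.0 - progress, 1e-8, None)          # lagging tasks get larger values
    fair_w = lag / lag.sum()
    mixed = (1.0 - fairness) * np.asarray(weights) + fairness * fair_w
    return mixed / mixed.sum()

# Example: task 2 has barely improved, so its sampling weight is boosted.
weights = np.array([0.4, 0.35, 0.25])
print(fair_adjust(weights, task_losses=[1.1, 0.9, 2.4], init_losses=[2.5, 2.5, 2.5]))
```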