Sampling and Loss Weights in Multi-Domain Training

📅 2025-11-10
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
Training large language models on heterogeneous multi-source data (e.g., Wikipedia, GitHub) suffers from sampling imbalance and loss imbalance, leading to high gradient variance and degraded generalization. Method: This paper systematically analyzes the complementary roles of sampling weights and loss weights in suppressing gradient variance and narrowing the generalization gap. We propose a joint optimization framework grounded in linear regression theory and SGD dynamics, deriving principled co-design criteria for both weight types through theoretical analysis and empirical validation. Contribution/Results: Our key insight is the first formal characterization that sampling and loss weights are not independently tunable but must be jointly configured to simultaneously ensure gradient stability and strong cross-domain generalization. Experiments demonstrate that our method significantly reduces training variance, accelerates convergence, and improves out-of-distribution generalization across diverse domains.

Technology Category

Application Category

📝 Abstract
In the training of large deep neural networks, there is a need for vast amounts of training data. To meet this need, data is collected from multiple domains, such as Wikipedia and GitHub. These domains are heterogeneous in both data quality and the diversity of information they provide. This raises the question of how much we should rely on each domain. Several methods have attempted to address this issue by assigning sampling weights to each data domain using heuristics or approximations. As a first step toward a deeper understanding of the role of data mixing, this work revisits the problem by studying two kinds of weights: sampling weights, which control how much each domain contributes in a batch, and loss weights, which scale the loss from each domain during training. Through a rigorous study of linear regression, we show that these two weights play complementary roles. First, they can reduce the variance of gradient estimates in iterative methods such as stochastic gradient descent (SGD). Second, they can improve generalization performance by reducing the generalization gap. We provide both theoretical and empirical support for these claims. We further study the joint dynamics of sampling weights and loss weights, examining how they can be combined to capture both contributions.
Problem

Research questions and friction points this paper is trying to address.

Optimizing domain sampling weights for heterogeneous training data
Balancing loss weights to improve generalization performance
Studying joint dynamics of sampling and loss weights
Innovation

Methods, ideas, or system contributions that make the work stand out.

Assigns sampling weights to control domain data contribution
Uses loss weights to scale domain loss during training
Combines sampling and loss weights to improve generalization
🔎 Similar Papers
No similar papers found.
M
Mahdi Salmani
University of Southern California
P
Pratik Worah
Google Research
Meisam Razaviyayn
Meisam Razaviyayn
University of Southern California
OptimizationMachine Learning
V
V. Mirrokni
Google Research