Optimal Multi-Distribution Learning

📅 2023-12-08
🏛️ Annual Conference on Computational Learning Theory (COLT)
📈 Citations: 15
Influential: 2
📄 PDF
🤖 AI Summary
This paper studies the optimal sampling and learning problem of minimizing the worst-distribution risk in multi-distribution learning (MDL). Given a hypothesis class of VC dimension $d$ and $k$ heterogeneous source distributions, we propose the first efficient algorithm achieving the information-theoretically optimal sample complexity $\Theta((d+k)/\varepsilon^2)$ (up to logarithmic factors), matching the known lower bound. Our method leverages an empirical risk minimization (ERM) oracle, integrates Rademacher complexity analysis, and employs an adaptive, on-demand sampling strategy to output an $\varepsilon$-optimal randomized hypothesis. We prove the necessity of randomization and expose fundamental limitations of deterministic approaches. The algorithm resolves three open problems posed at COLT 2023; it is oracle-efficient, applicable to both VC-type and Rademacher-type classes, and achieves theoretical optimality and computational feasibility simultaneously.
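For concreteness, the worst-distribution risk referenced above, and the sample complexity the paper attains, can be written as follows. This is a sketch in standard MDL notation; the class $\mathcal{H}$, bounded loss $\ell$, and distributions $\mathcal{D}_1, \dots, \mathcal{D}_k$ are assumed conventions rather than taken verbatim from the paper:

```latex
% Worst-distribution risk targeted by MDL, and the optimal sample complexity
% (standard notation assumed: bounded loss \ell, sources D_1, ..., D_k):
\min_{h \in \mathcal{H}} \; \max_{i \in [k]} \;
  \mathbb{E}_{(x, y) \sim \mathcal{D}_i} \bigl[ \ell(h(x), y) \bigr],
\qquad
n(\varepsilon) = \widetilde{\Theta}\!\left( \frac{d + k}{\varepsilon^2} \right).
```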
📝 Abstract
Multi-distribution learning (MDL), which seeks to learn a shared model that minimizes the worst-case risk across $k$ distinct data distributions, has emerged as a unified framework in response to the evolving demand for robustness, fairness, multi-group collaboration, etc. Achieving data-efficient MDL necessitates adaptive sampling, also called on-demand sampling, throughout the learning process. However, there exist substantial gaps between the state-of-the-art upper and lower bounds on the optimal sample complexity. Focusing on a hypothesis class of Vapnik-Chervonenkis (VC) dimension $d$, we propose a novel algorithm that yields an $\varepsilon$-optimal randomized hypothesis with a sample complexity on the order of $(d+k)/\varepsilon^2$ (modulo some logarithmic factor), matching the best-known lower bound. Our algorithmic ideas and theory are further extended to accommodate Rademacher classes. The proposed algorithms are oracle-efficient, accessing the hypothesis class solely through an empirical risk minimization oracle. Additionally, we establish the necessity of randomization, revealing a large sample size barrier when only deterministic hypotheses are permitted. These findings resolve three open problems presented at COLT 2023 (Awasthi et al., 2023, Problems 1, 3, and 4).
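The abstract's key ingredients (on-demand sampling, an ERM oracle, and a randomized output) follow the familiar min-max game template for MDL. Below is a minimal Python sketch of that template; `sample_from`, `erm_oracle`, and `loss` are assumed interfaces, and the exponential-weights schedule is illustrative, not the paper's exact algorithm:

```python
import math
import random


def multi_distribution_learn(sample_from, erm_oracle, loss, k,
                             rounds=100, batch=4, eta=0.1):
    """Sketch of the no-regret min-max template behind oracle-efficient MDL.

    Assumed interfaces (not from the paper):
      sample_from(i, n) -> list of n labelled examples from distribution i
      erm_oracle(data)  -> hypothesis minimizing empirical risk on `data`
      loss(h, example)  -> loss of h on one example, assumed bounded in [0, 1]
    """
    weights = [1.0] * k           # adversary weights over the k distributions
    data, iterates = [], []

    for _ in range(rounds):
        # On-demand sampling: draw fresh examples, favouring distributions
        # the adversary currently considers hardest.
        total = sum(weights)
        probs = [w / total for w in weights]
        for i in random.choices(range(k), weights=probs, k=batch):
            data.extend((i, ex) for ex in sample_from(i, 1))

        # Learner best-responds with one ERM-oracle call on all data so far.
        h = erm_oracle([ex for _, ex in data])
        iterates.append(h)

        # Exponential-weights update: upweight distributions on which the
        # current hypothesis incurs high empirical loss.
        for i in range(k):
            losses = [loss(h, ex) for j, ex in data if j == i]
            if losses:
                weights[i] *= math.exp(eta * sum(losses) / len(losses))

    # Randomized output: a uniform mixture over the iterates.  To predict,
    # draw one hypothesis from this list at random and apply it.
    return iterates
```

Returning the full list of iterates, i.e., a uniform mixture rather than a single hypothesis, is what makes the output randomized; per the abstract, such randomization is necessary to avoid the large sample-size barrier that deterministic hypotheses incur.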
Problem

Research questions and friction points this paper is trying to address.

Minimizing worst-case risk across multiple data distributions
Bridging sample complexity gaps in multi-distribution learning
Developing efficient algorithms for VC and Rademacher classes
Innovation

Methods, ideas, or system contributions that make the work stand out.

Oracle-efficient algorithm using empirical risk minimization
Sample complexity matching lower bounds for VC classes
Extended theory to accommodate Rademacher complexity classes
Zihan Zhang
Department of Electrical and Computer Engineering, Princeton University
Wenhao Zhan
Graduate Student, Princeton University
reinforcement learning, large language models, statistics
Yuxin Chen
Department of Statistics and Data Science, University of Pennsylvania
Simon S. Du
Paul G. Allen School of Computer Science and Engineering, University of Washington
Jason D. Lee
Associate Professor of EECS & Statistics at UC Berkeley
Machine Learning Theory, Machine Learning, Artificial Intelligence, Statistics, Optimization