🤖 AI Summary
Data pruning exacerbates class imbalance and impairs worst-class generalization, an effect that has not been systematically studied before. Method: We propose DRoP, a distributionally robust approach to data pruning that selects per-class pruning ratios via distributionally robust optimization (DRO) and prunes uniformly at random within each class. The approach is grounded in a theoretical analysis of the classification risk in a mixture of Gaussians, which shows that appropriate class pruning ratios combined with within-class random pruning can improve worst-class performance. Results: On standard vision benchmarks, DRoP significantly improves worst-class accuracy under high pruning ratios (average +3.2%), with only marginal degradation in overall accuracy (<0.5%). Crucially, robustness gains continue to grow as the pruning ratio increases, demonstrating superior worst-case generalization under severe data reduction.
📝 Abstract
In the era of exceptionally data-hungry models, careful selection of the training data is essential to mitigate the extensive costs of deep learning. Data pruning offers a solution by removing redundant or uninformative samples from the dataset, which yields faster convergence and improved neural scaling laws. However, little is known about its impact on the classification bias of the trained models. We conduct the first systematic study of this effect and reveal that existing data pruning algorithms can produce highly biased classifiers. We present a theoretical analysis of the classification risk in a mixture of Gaussians to argue that choosing appropriate class pruning ratios, coupled with random pruning within classes, has the potential to improve worst-class performance. We thus propose DRoP, a distributionally robust approach to pruning, and empirically demonstrate its performance on standard computer vision benchmarks. In sharp contrast to existing algorithms, our proposed method continues to improve distributional robustness at a tolerable drop in average performance as we prune more from the datasets.
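The class-aware scheme the abstract describes (per-class pruning ratios combined with uniform random pruning within each class) can be sketched as follows. This is a minimal illustration, not the paper's method: the error-proportional budget allocation and all function and parameter names (`drop_style_prune`, `class_errors`, `alpha`) are assumptions introduced here for clarity.

```python
import numpy as np

def drop_style_prune(labels, class_errors, overall_keep_frac, alpha=1.0, seed=0):
    """Sketch of class-aware random pruning in the spirit of DRoP.

    Worse-performing classes retain a larger share of their samples;
    within each class, samples are dropped uniformly at random. The
    error-proportional weighting (errors ** alpha) is illustrative,
    not the paper's exact allocation rule.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    counts = np.array([(labels == c).sum() for c in classes])
    errors = np.asarray([class_errors[c] for c in classes], dtype=float)

    # Split the total kept-sample budget across classes in proportion to
    # (class size * class error): higher-error classes keep more data.
    budget = int(round(overall_keep_frac * labels.size))
    weights = counts * errors ** alpha
    keep_per_class = np.minimum(
        counts, np.round(budget * weights / weights.sum()).astype(int)
    )

    # Uniform random pruning within each class.
    kept = []
    for c, k in zip(classes, keep_per_class):
        idx = np.flatnonzero(labels == c)
        kept.append(rng.choice(idx, size=k, replace=False))
    return np.sort(np.concatenate(kept))
```

For example, with two equally sized classes where class 1 has four times the validation error of class 0, keeping 50% of the data retains far more samples from class 1 than from class 0.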