🤖 AI Summary
Models often fail to generalize to minority groups in multi-source heterogeneous data due to group-level distributional uncertainty—arising from label noise, non-stationarity, and estimation bias in group distribution.
Method: We propose a novel Wasserstein distributionally robust optimization (DRO) framework that explicitly models group-level distributional uncertainty within the worst-group objective—departing from conventional DRO assumptions that require precise group distribution estimates. We formulate a joint Wasserstein-DRO problem and design a gradient descent-ascent algorithm with theoretical convergence guarantees.
Results: Extensive experiments on real-world non-stationary datasets demonstrate that our method significantly outperforms state-of-the-art approaches, particularly under distribution shift and label noise. It achieves superior robustness and generalization to minority groups, validating the efficacy of explicitly incorporating group-level ambiguity into the DRO formulation.
📝 Abstract
The performance of machine learning (ML) models critically depends on the quality and representativeness of the training data. In applications with multiple heterogeneous data generating sources, standard ML methods often learn spurious correlations that perform well on average but degrade performance for atypical or underrepresented groups. Prior work addresses this issue by optimizing the worst-group performance. However, these approaches typically assume that the underlying data distributions for each group can be accurately estimated using the training data, a condition that is frequently violated in noisy, non-stationary, and evolving environments. In this work, we propose a novel framework that relies on Wasserstein-based distributionally robust optimization (DRO) to account for the distributional uncertainty within each group, while simultaneously preserving the objective of improving the worst-group performance. We develop a gradient descent-ascent algorithm to solve the proposed DRO problem and provide convergence results. Finally, we validate the effectiveness of our method on real-world data.