Gradient Extrapolation for Debiased Representation Learning

📅 2025-03-17
📈 Citations: 0
Influential: 0
🤖 AI Summary
Machine learning models trained via empirical risk minimization (ERM) often exploit spurious correlations, leading to poor generalization under distribution shifts. To address this, we propose Gradient Extrapolation for Debiased Representation Learning (GERNE), a representation debiasing framework that linearly extrapolates the gradients of two batches with different amounts of spurious correlation to learn debiased representations, whether bias attributes are known or unknown. GERNE treats ERM, reweighting, and resampling as special cases of a single formulation. We derive upper and lower bounds on the extrapolation factor that ensure convergence, and show that tuning this factor lets the method target either group-balanced accuracy (GBA) or worst-group accuracy (WGA). Experiments on five vision benchmarks and one NLP benchmark show competitive and often superior worst-group and group-balanced performance relative to state-of-the-art baselines.
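One way to write the dual-batch extrapolation is sketched below. The parametrization and symbols are assumptions chosen for illustration, not taken from the paper: $\mathcal{B}_b$ is a batch carrying more spurious correlation, $\mathcal{B}_u$ a batch carrying less, and $\beta$ the extrapolation factor. Because the gradient is linear, extrapolating the two gradients is the same as taking the gradient of an extrapolated loss:

```latex
\begin{aligned}
g_b &= \nabla_\theta \mathcal{L}(\theta;\mathcal{B}_b), \qquad
g_u = \nabla_\theta \mathcal{L}(\theta;\mathcal{B}_u),\\[2pt]
g_{\text{ext}} &= g_u + \beta\,(g_u - g_b)
  = \nabla_\theta\Big[(1+\beta)\,\mathcal{L}(\theta;\mathcal{B}_u) - \beta\,\mathcal{L}(\theta;\mathcal{B}_b)\Big].
\end{aligned}
```

Under this parametrization, $\beta = -1$ gives $g_b$ (plain ERM when $\mathcal{B}_b$ is drawn uniformly from the training set), $\beta = 0$ gives $g_u$ (resampling when $\mathcal{B}_u$ is group-balanced), and $\beta > 0$ pushes the update past $g_u$, away from the more biased direction. The admissible range of $\beta$ comes from the paper's derived bounds, which are not reproduced here.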

📝 Abstract
Machine learning classification models trained with empirical risk minimization (ERM) often inadvertently rely on spurious correlations. When these unintended associations between non-target attributes and target labels are absent from the test data, they lead to poor generalization. This paper addresses the problem from a model optimization perspective and proposes a novel method, Gradient Extrapolation for Debiased Representation Learning (GERNE), designed to learn debiased representations whether the bias attributes are known or unknown during training. GERNE uses two distinct batches with different amounts of spurious correlation and defines the target gradient as the linear extrapolation of the two gradients computed from each batch's loss. It is shown that if the extrapolated gradient is directed toward the gradient of the batch with less spurious correlation, it can guide training toward a debiased model. GERNE can serve as a general framework for debiasing, with methods such as ERM, reweighting, and resampling shown to be special cases. Theoretical upper and lower bounds on the extrapolation factor are derived to ensure convergence. By adjusting this factor, GERNE can be adapted to maximize either the Group-Balanced Accuracy (GBA) or the Worst-Group Accuracy (WGA). The proposed approach is validated on five vision benchmarks and one NLP benchmark, demonstrating competitive and often superior performance compared to state-of-the-art baseline methods.
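A minimal NumPy sketch of a training loop driven by an extrapolated gradient follows. It assumes a logistic-regression model with mean binary cross-entropy loss and the illustrative parametrization g_ext = g_u + beta * (g_u - g_b); the function names (`bce_grad`, `extrapolated_grad`, `train`) and the batch-sampler callables are hypothetical and do not come from the paper, whose models, losses, and factor schedule may differ. In practice `beta` should stay within the bounds the paper derives.

```python
import numpy as np

def bce_grad(w, X, y):
    """Gradient w.r.t. w of the mean binary cross-entropy of a logistic model."""
    p = 1.0 / (1.0 + np.exp(-X @ w))  # predicted probabilities
    return X.T @ (p - y) / len(y)

def extrapolated_grad(w, biased_batch, less_biased_batch, beta):
    """Linearly extrapolate two batch gradients (illustrative parametrization).

    g_ext = g_u + beta * (g_u - g_b); beta = 0 returns g_u, beta = -1 returns g_b.
    Each batch is an (X, y) tuple.
    """
    g_b = bce_grad(w, *biased_batch)       # batch with more spurious correlation
    g_u = bce_grad(w, *less_biased_batch)  # batch with less spurious correlation
    return g_u + beta * (g_u - g_b)

def train(w, sample_biased, sample_less_biased, beta, lr=0.1, steps=100):
    """Plain gradient descent using the extrapolated gradient at every step.

    sample_biased / sample_less_biased are hypothetical callables that return
    a fresh (X, y) batch each time they are called.
    """
    for _ in range(steps):
        g = extrapolated_grad(w, sample_biased(), sample_less_biased(), beta)
        w = w - lr * g
    return w
```

With an autodiff framework the same update can be obtained from a single backward pass on the combined loss (1 + beta) * L_u - beta * L_b, since the gradient is linear in the loss.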
Problem

Addresses poor generalization due to spurious correlations in ERM-trained models.
Proposes GERNE for debiased representation learning with known or unknown bias attributes.
Validates GERNE on five vision benchmarks and one NLP benchmark, with competitive and often superior performance.
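The two metrics GERNE can be tuned for are sketched below, assuming the usual definitions from the group-robustness literature (the summary does not define them): GBA is the unweighted mean of per-group accuracies and WGA is the minimum per-group accuracy. Groups are commonly formed from (target label, bias attribute) pairs; that grouping, and the helper names `group_accuracies` and `gba_wga`, are assumptions for illustration.

```python
from collections import defaultdict

def group_accuracies(y_true, y_pred, groups):
    """Accuracy of the predictions within each group."""
    correct, total = defaultdict(int), defaultdict(int)
    for t, p, g in zip(y_true, y_pred, groups):
        total[g] += 1
        correct[g] += int(t == p)
    return {g: correct[g] / total[g] for g in total}

def gba_wga(y_true, y_pred, groups):
    """Return (Group-Balanced Accuracy, Worst-Group Accuracy)."""
    accs = group_accuracies(y_true, y_pred, groups)
    gba = sum(accs.values()) / len(accs)  # every group weighted equally
    wga = min(accs.values())              # accuracy of the weakest group
    return gba, wga

# Majority group "a" scores 3/4, minority group "b" scores 1/2.
print(gba_wga([0, 0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 0], ["a", "a", "a", "a", "b", "b"]))
# → (0.625, 0.5)
```

Overall accuracy in this toy case is 4/6, which hides the weaker minority group; GBA and WGA expose it, which is why they are the targets for debiasing.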
Innovation

Gradient extrapolation for debiased learning
Two batches with different amounts of spurious correlation, combined by gradient extrapolation
Adjustable extrapolation factor, bounded to ensure convergence, for targeting GBA or WGA
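The two batches could be built as sketched below, assuming group labels are available (the known-attribute case): a uniformly sampled batch inherits the training set's spurious correlations, while a group-balanced batch reduces them. This pairing is an illustrative assumption; the summary does not say exactly how GERNE samples its batches, nor how it builds the less-biased batch when attributes are unknown. The helper names are hypothetical.

```python
import numpy as np

def uniform_batch(rng, n, batch_size):
    """Indices drawn uniformly without replacement: keeps the data's spurious correlations."""
    return rng.choice(n, size=batch_size, replace=False)

def group_balanced_batch(rng, groups, batch_size):
    """Indices with batch_size // n_groups samples from each group.

    Sampling is with replacement inside each group, so small minority
    groups can still fill their share of the batch.
    """
    groups = np.asarray(groups)
    per_group = batch_size // len(np.unique(groups))
    picks = [
        rng.choice(np.flatnonzero(groups == g), size=per_group, replace=True)
        for g in np.unique(groups)
    ]
    return np.concatenate(picks)

# Toy training set: group 0 dominates 90/10.
rng = np.random.default_rng(0)
groups = np.array([0] * 90 + [1] * 10)
biased_idx = uniform_batch(rng, len(groups), 20)
balanced_idx = group_balanced_batch(rng, groups, 20)
print(np.bincount(groups[balanced_idx], minlength=2))  # 10 per group
```

The uniform batch would supply the more biased gradient and the balanced batch the less biased one; the extrapolation factor then decides how far the update moves beyond the balanced gradient.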