Did Models Sufficient Learn? Attribution-Guided Training via Subset-Selected Counterfactual Augmentation

📅 2025-11-15
📈 Citations: 0
Influential: 0
🤖 AI Summary
Vision models often rely on local, spurious correlations for prediction, which results in poor out-of-distribution (OOD) generalization. To address this, we propose attribution-guided counterfactual augmentation, a training paradigm centered on Subset-Selected Counterfactual Augmentation (SS-CA). SS-CA uses LIMA-based attribution to identify a minimal set of critical regions, replaces those regions with natural background patches to produce high-fidelity counterfactual samples, and trains the model jointly on the original and augmented inputs. This encourages the model to learn more complete, robust causal features rather than superficial statistical associations. On ImageNet and several OOD benchmarks, including ImageNet-R, ImageNet-S, and noise-corrupted variants, the method improves both in-distribution accuracy and OOD generalization. The results suggest that making the model's causal learning more complete is an effective route to robust visual generalization.
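The region-selection step can be illustrated with a simple greedy deletion search. This is a minimal sketch under stated assumptions, not the paper's Counterfactual LIMA: it swaps LIMA's subset-selection objective for plain greedy deletion, splits the image into a fixed patch grid, and "removes" a region by filling it with a constant value instead of natural background. `grid_masks`, `greedy_counterfactual_regions`, and the toy scorer `toy_scores` are hypothetical names introduced only for illustration.

```python
import numpy as np


def grid_masks(h, w, patch):
    """One boolean H x W mask per non-overlapping patch x patch cell, in row-major order."""
    masks = []
    for r in range(0, h, patch):
        for c in range(0, w, patch):
            m = np.zeros((h, w), dtype=bool)
            m[r:r + patch, c:c + patch] = True
            masks.append(m)
    return masks


def greedy_counterfactual_regions(image, label, score_fn, masks, fill=0.0):
    """Greedy deletion search (a simplification, not the paper's Counterfactual LIMA).

    At each step, remove the region whose deletion lowers the score of `label` the most,
    and stop as soon as the arg-max class is no longer `label`. Returns the selected
    region indices in selection order, or None if the prediction never flips.
    """
    removed = np.zeros(image.shape[:2], dtype=bool)
    selected = []
    remaining = list(range(len(masks)))
    while remaining:
        best_idx, best_scores = None, None
        for i in remaining:
            x = image.copy()
            x[removed | masks[i]] = fill  # assumption: constant fill stands in for "removal"
            scores = score_fn(x)
            if best_scores is None or scores[label] < best_scores[label]:
                best_idx, best_scores = i, scores
        selected.append(best_idx)
        remaining.remove(best_idx)
        removed |= masks[best_idx]
        if int(np.argmax(best_scores)) != label:
            return selected
    return None


# Hypothetical toy classifier: class-0 evidence is the mean of the top two rows,
# while class 1 has a fixed score of 0.4.
def toy_scores(x):
    return np.array([x[:2, :].mean(), 0.4])


image = np.ones((4, 4))
print(greedy_counterfactual_regions(image, 0, toy_scores, grid_masks(4, 4, 2)))
```

On this toy input, deleting a single top patch does not flip the prediction, so the search returns a two-region set. Greedy selection keeps each step cheap but, unlike an exhaustive search, does not guarantee the globally smallest set.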

📝 Abstract
In current visual model training, models often rely on only a limited set of sufficient causes for their predictions, which makes them sensitive to distribution shifts or to the absence of key features. Attribution methods can accurately identify a model's critical regions. However, masking these areas to create counterfactuals often causes the model to misclassify the target, while humans can still easily recognize it. This divergence suggests that the model's learned dependencies may not be sufficiently causal. To address this issue, we propose Subset-Selected Counterfactual Augmentation (SS-CA), which integrates counterfactual explanations directly into the training process for targeted intervention. Building on the subset-selection-based LIMA attribution method, we develop Counterfactual LIMA to identify minimal sets of spatial regions whose removal can selectively alter model predictions. Leveraging these attributions, we introduce a data augmentation strategy that replaces the identified regions with natural background, and we train the model jointly on both augmented and original samples to mitigate incomplete causal learning. Extensive experiments across multiple ImageNet variants show that SS-CA improves generalization on in-distribution (ID) test data and achieves superior performance on out-of-distribution (OOD) benchmarks such as ImageNet-R and ImageNet-S. Under perturbations such as noise, models trained with SS-CA also exhibit enhanced generalization, demonstrating that our approach effectively uses interpretability insights to correct model deficiencies and improve both performance and robustness.
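Once a region set has been selected, the augmentation itself amounts to pixel replacement. The sketch below assumes a background image of the same size is already available (the abstract does not specify how the natural background is obtained), and `replace_with_background` is a hypothetical helper, not the authors' code.

```python
import numpy as np


def replace_with_background(image, region_mask, background):
    """Hypothetical helper: copy `image`, then overwrite the pixels inside `region_mask`
    with the corresponding pixels of `background`.

    image, background: arrays of shape (H, W) or (H, W, C) with identical shapes.
    region_mask: boolean array of shape (H, W) marking the regions chosen for removal.
    """
    if image.shape != background.shape:
        raise ValueError("image and background must have the same shape")
    if region_mask.shape != image.shape[:2]:
        raise ValueError("region_mask must match the image's spatial size")
    augmented = image.copy()  # leave the original sample untouched for joint training
    augmented[region_mask] = background[region_mask]
    return augmented


# Usage with stand-in data: a bright "object" image and a dark "background" image.
img = np.full((4, 4, 3), 200, dtype=np.uint8)
bg = np.full((4, 4, 3), 30, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[:2, :2] = True  # region assumed to come from the attribution step
aug = replace_with_background(img, mask, bg)
print(aug[0, 0], aug[3, 3])  # background pixel at (0, 0), original pixel at (3, 3)
```

Copying rather than editing in place matters here, because the original sample is still needed alongside its counterfactual during training.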
Problem

Models rely on a limited set of sufficient causes, making them sensitive to distribution shifts.
Attribution methods reveal critical regions, but masking those regions causes the model to misclassify targets that humans still recognize.
The approach addresses incomplete causal learning to enhance model robustness.
Innovation

Counterfactual LIMA identifies minimal sets of spatial regions whose removal alters the model's prediction
SS-CA replaces critical regions with natural background
Joint training on augmented and original samples improves robustness
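Joint training on original and augmented samples can be written as a combined loss. This is a minimal NumPy sketch that assumes (1) each augmented image keeps its original label and (2) the two cross-entropy terms are summed with a hypothetical weight `aug_weight`; the paper's exact objective and weighting are not given here, and `joint_loss` is an illustrative name.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy; logits has shape (N, K), labels has shape (N,) of class ids."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def joint_loss(logits_orig, logits_aug, labels, aug_weight=1.0):
    """Hypothetical joint objective: CE on original inputs plus weighted CE on
    counterfactual-augmented inputs, both scored against the same labels."""
    return cross_entropy(logits_orig, labels) + aug_weight * cross_entropy(logits_aug, labels)


# Usage: confident logits on the originals, uninformative logits on the augmented copies.
logits = np.array([[2.0, 0.0], [0.0, 2.0]])
labels = np.array([0, 1])
print(joint_loss(logits, np.zeros_like(logits), labels, aug_weight=0.5))
```

Under these assumptions, the augmented term penalizes a model that can only classify the object when the replaced regions are present, which pushes it toward using the remaining evidence as well.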
Yannan Chen
School of Cyber Science and Technology, Shenzhen Campus of Sun Yat-sen University
Ruoyu Chen
Institute of Information Engineering, Chinese Academy of Sciences
Explainable AI, Trustworthy AI, Foundation Model
Bin Zeng
National High Magnetic Field Lab
Superconductivity
Wei Wang
School of Cyber Science and Technology, Shenzhen Campus of Sun Yat-sen University
Shiming Liu
RAMS Lab, Huawei Inc.
Qunli Zhang
Imperial College London
Zheng Hu
RAMS Lab, Munich Research Center, Huawei Düsseldorf GmbH
Laiyuan Wang
School of Flexible Electronics, Sun Yat-sen University
Yaowei Wang
The Hong Kong Polytechnic University
Xiaochun Cao
Sun Yat-sen University
Computer Vision, Artificial Intelligence, Multimedia, Machine Learning