AI Summary
To address the statistical imbalance between treatment groups, which biases Conditional Average Treatment Effect (CATE) estimation in causal inference, this paper proposes a model-agnostic counterfactual data augmentation method. It pioneers the integration of contrastive learning into counterfactual reasoning. The method constructs a representation space in which closeness reflects similarity of potential outcomes, and this enables reliable imputation of counterfactual outcomes for individuals with close neighbors in the other treatment group. Theoretically, the method mitigates the distribution shift between treatment groups and suppresses overfitting. Empirical evaluation on synthetic and semi-synthetic benchmarks demonstrates substantial improvements: an average RMSE reduction of 18.7% across mainstream CATE estimators, a decrease of over 30% in generalization error, and enhanced robustness, all without reliance on a specific model architecture. The core contribution lies in unifying contrastive learning with counterfactual augmentation, establishing a general, interpretable, low-bias enhancement paradigm for CATE estimation that generalizes well.
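The summary does not specify the contrastive objective, so the following is only a minimal sketch under stated assumptions. It evaluates a SupCon-style (supervised contrastive) loss in which, hypothetically, two units count as positives when they share a treatment arm and their observed outcomes differ by less than `eps`. Within one arm the observed outcome is the potential outcome under that arm's treatment, so pulling such pairs together pushes the space toward "close embeddings have similar potential outcomes". The names `outcome_contrastive_loss`, `eps`, and `tau`, and the choice of cosine similarity, are illustrative and not taken from the paper; the encoder and its training loop are left out.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot / (nu * nv)


def outcome_contrastive_loss(
    z: List[Sequence[float]],
    t: List[int],
    y: List[float],
    eps: float = 0.5,
    tau: float = 0.1,
) -> float:
    """SupCon-style loss over one batch of embeddings z (hypothetical objective).

    Assumed positive rule: j is a positive for anchor i if both units are in
    the same treatment arm and their observed outcomes differ by less than eps.
    Every other unit in the batch appears in the softmax denominator.
    """
    n = len(z)
    total, anchors = 0.0, 0
    for i in range(n):
        positives = [
            j for j in range(n)
            if j != i and t[j] == t[i] and abs(y[j] - y[i]) < eps
        ]
        if not positives:
            continue  # anchors without positives contribute nothing
        logits = {j: cosine(z[i], z[j]) / tau for j in range(n) if j != i}
        # numerically stable log-sum-exp over all non-anchor units
        m = max(logits.values())
        log_denom = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        # mean negative log-probability that the softmax assigns to the positives
        total += -sum(logits[j] - log_denom for j in positives) / len(positives)
        anchors += 1
    return total / anchors if anchors else 0.0


# Toy batch: units 0,1 (control) and units 2,3 (treated) form the positive pairs.
t = [0, 0, 1, 1]
y = [1.0, 1.1, 5.0, 5.2]
z_aligned = [(1.0, 0.0), (1.0, 0.0), (0.0, 1.0), (0.0, 1.0)]     # positives together
z_misaligned = [(1.0, 0.0), (0.0, 1.0), (1.0, 0.0), (0.0, 1.0)]  # positives apart
loss_aligned = outcome_contrastive_loss(z_aligned, t, y)
loss_misaligned = outcome_contrastive_loss(z_misaligned, t, y)
```

With the aligned embeddings the loss is close to zero. With the misaligned ones it is roughly `1 / tau`, because every anchor's nearest unit is a negative. Minimizing this loss with respect to an encoder would therefore favor the first geometry.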
Abstract
Statistical disparity between treatment groups is one of the most significant challenges in estimating Conditional Average Treatment Effects (CATE). To address this, we introduce a model-agnostic data augmentation method that imputes the counterfactual outcomes for a selected subset of individuals. Specifically, we use contrastive learning to learn a representation space and a similarity measure such that individuals who are close in the learned representation space, as judged by the learned similarity measure, have similar potential outcomes. This property ensures reliable imputation of counterfactual outcomes for individuals that have close neighbors in the alternative treatment group. By augmenting the original dataset with these reliable imputations, we can effectively reduce the discrepancy between treatment groups while introducing minimal imputation error. The augmented dataset is then used to train CATE estimation models. Theoretical analysis and experimental studies on synthetic and semi-synthetic benchmarks demonstrate that our method significantly improves both performance and robustness to overfitting across state-of-the-art models.
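The imputation and augmentation step described above can be sketched as follows. This is a minimal illustration under several assumptions, not the paper's procedure. The embeddings `z` are assumed to come from an already trained contrastive encoder. "Close" is taken to mean cosine similarity at or above a hypothetical `threshold`. The counterfactual outcome is taken to be the mean observed outcome of up to `k` such opposite-arm neighbors. The names `Unit` and `augment_with_counterfactuals` are illustrative.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class Unit:
    x: Tuple[float, ...]  # covariates
    t: int                # treatment indicator, 0 or 1
    y: float              # outcome under treatment t (observed or imputed)


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return dot / (nu * nv)


def augment_with_counterfactuals(
    units: List[Unit],
    z: List[Sequence[float]],
    threshold: float = 0.9,
    k: int = 1,
) -> List[Unit]:
    """Append an imputed counterfactual for every unit whose top-ranked
    opposite-arm neighbours in the learned space z are similar enough.

    The imputed outcome is the mean observed outcome of the (at most k)
    top-ranked opposite-arm neighbours with similarity >= threshold; units
    without such a neighbour get no imputation. Only original, observed
    outcomes are ever used as imputation sources.
    """
    augmented = list(units)
    for i, u in enumerate(units):
        ranked = sorted(
            ((cosine(z[i], z[j]), j) for j, v in enumerate(units) if v.t != u.t),
            reverse=True,
        )
        close = [j for sim, j in ranked[:k] if sim >= threshold]
        if close:
            y_hat = sum(units[j].y for j in close) / len(close)
            augmented.append(Unit(u.x, 1 - u.t, y_hat))
    return augmented


# Toy data: units 0 and 1 are close in z but fall in different arms.
units = [
    Unit((0.0,), 0, 1.0),
    Unit((0.1,), 1, 3.0),
    Unit((5.0,), 0, 2.0),
    Unit((9.0,), 1, 7.0),
]
z = [(1.0, 0.0), (0.99, 0.14), (0.0, 1.0), (-1.0, 0.0)]
augmented = augment_with_counterfactuals(units, z)
# augmented keeps the 4 originals and adds Unit((0.0,), 1, 3.0) and Unit((0.1,), 0, 1.0)
```

In this sketch `threshold` carries the trade-off the abstract describes. A stricter value imputes fewer counterfactuals and so closes less of the gap between groups, but each imputation relies on a closer neighbor. Because the output is just a longer list of `(x, t, y)` records, it can be fed to any CATE learner, which is what makes the augmentation model-agnostic.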