🤖 AI Summary
Shortcut learning in medical AI, where models erroneously rely on non-clinical imaging artifacts rather than pathologically relevant features, exploits spurious correlations in the training data and compromises clinical reliability.
Method: We propose the first XAI-driven, semi-automated framework for shortcut identification and disentanglement, integrating gradient-based interpretability methods (Grad-CAM, Integrated Gradients), causal attribution analysis, adversarial data distillation, and model editing. It enables sample- and pixel-level bias localization and mitigation without expert re-annotation. The framework is architecture-agnostic (supporting CNNs and ViTs) and generalizable across multimodal medical data.
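To make the attribution step concrete, here is a minimal, illustrative sketch of Integrated Gradients (one of the gradient-based interpretability methods named above), written in pure Python with finite-difference gradients on a toy scalar model. This is a didactic approximation under simplified assumptions, not the paper's implementation; in practice the method is applied to deep image classifiers with automatic differentiation.

```python
def grad(f, x, eps=1e-6):
    """Central finite-difference gradient of scalar function f at point x."""
    g = []
    for i in range(len(x)):
        xp = list(x); xp[i] += eps
        xm = list(x); xm[i] -= eps
        g.append((f(xp) - f(xm)) / (2 * eps))
    return g

def integrated_gradients(f, x, baseline, steps=50):
    """Attribute f(x) - f(baseline) to each input dimension by averaging
    gradients along the straight-line path from baseline to x."""
    avg_grad = [0.0] * len(x)
    for k in range(1, steps + 1):
        alpha = k / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad(f, point)
        for i in range(len(x)):
            avg_grad[i] += g[i] / steps
    # Scale the path-averaged gradient by the input-baseline difference.
    return [(xi - b) * a for xi, b, a in zip(x, baseline, avg_grad)]
```

For a linear model the attributions recover the weights times the input-baseline difference, and by the completeness axiom they sum to `f(x) - f(baseline)`.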
Results: Evaluated on four medical datasets, our approach effectively identifies and eliminates artifact-induced spurious associations, significantly improving out-of-distribution robustness and clinical trustworthiness of VGG16, ResNet50, and ViT models.
📝 Abstract
Deep neural networks are increasingly employed in high-stakes medical applications, despite their tendency toward shortcut learning in the presence of spurious correlations, which can have potentially fatal consequences in practice. Detecting and mitigating shortcut behavior is a challenging task that often requires significant labeling effort from domain experts. To alleviate this problem, we introduce a semi-automated framework for the identification of spurious behavior from both the data and model perspectives by leveraging insights from eXplainable Artificial Intelligence (XAI). This allows the retrieval of spurious data points and the detection of model circuits that encode the associated prediction rules. Moreover, we demonstrate how these shortcut encodings can be used for XAI-based sample- and pixel-level data annotation, providing valuable information for bias mitigation methods to unlearn the undesired shortcut behavior. We show the applicability of our framework on four medical datasets across two modalities, featuring controlled and real-world spurious correlations caused by data artifacts. We successfully identify and mitigate these biases in VGG16, ResNet50, and contemporary Vision Transformer models, ultimately increasing their robustness and applicability for real-world medical tasks.
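The XAI-based sample-level annotation described above can be sketched as follows: a sample is flagged as shortcut-affected when a large share of its attribution mass falls inside a suspected artifact region. The function names and the 0.5 threshold are hypothetical choices for illustration, not details taken from the paper.

```python
def artifact_relevance_ratio(heatmap, artifact_mask):
    """Fraction of total absolute attribution lying inside the artifact region.

    heatmap: 2D list of attribution scores (e.g., from Grad-CAM or IG).
    artifact_mask: 2D list of 0/1 flags marking the suspected artifact pixels.
    """
    total = sum(abs(v) for row in heatmap for v in row)
    if total == 0:
        return 0.0
    inside = sum(abs(v)
                 for row, mrow in zip(heatmap, artifact_mask)
                 for v, m in zip(row, mrow) if m)
    return inside / total

def flag_biased_samples(heatmaps, artifact_mask, threshold=0.5):
    """Return indices of samples whose attributions concentrate on the artifact."""
    return [i for i, h in enumerate(heatmaps)
            if artifact_relevance_ratio(h, artifact_mask) > threshold]
```

Samples flagged this way can then feed a bias mitigation method without requiring expert re-annotation of the whole dataset.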