🤖 AI Summary
Medical visual question answering (VQA) models often generalize poorly because they rely on spurious correlations in datasets, such as fixed anatomical patterns or question templates. This work proposes Learnable Causal Trimming (LCT), the first end-to-end learnable causal trimming framework: it builds global prototypes via a momentum-updated Dynamic Anatomical Feature Bank and introduces a differentiable trimming module that adaptively suppresses non-causal features highly correlated with these prototypes, thereby emphasizing instance-specific causal evidence. The method significantly improves robustness and generalization across multiple benchmarks (VQA-RAD, SLAKE, SLAKE-CP, and PathVQA), outperforming existing debiasing approaches.
📝 Abstract
Medical Visual Question Answering (MedVQA) models often exhibit limited generalization due to reliance on dataset-specific correlations, such as recurring anatomical patterns or question-type regularities, rather than genuine diagnostic evidence. Existing causal approaches are typically implemented as static adjustments or post-hoc corrections. To address this issue, we propose a Learnable Causal Trimming (LCT) framework that integrates causal pruning into end-to-end optimization. We introduce a Dynamic Anatomical Feature Bank (DAFB), updated via a momentum mechanism, to capture global prototypes of frequent anatomical and linguistic patterns, serving as an approximation of dataset-level regularities. We further design a differentiable trimming module that estimates the dependency between instance-level representations and the global feature bank. Features highly correlated with global prototypes are softly suppressed, while instance-specific evidence is emphasized. This learnable mechanism encourages the model to adaptively prioritize causal signals over spurious correlations. Experiments on VQA-RAD, SLAKE, SLAKE-CP, and PathVQA demonstrate that LCT consistently improves robustness and generalization over existing debiasing strategies.
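The two components described above, a momentum-updated prototype bank and a differentiable soft-suppression step, can be sketched in a few lines. This is a minimal illustrative sketch, not the authors' released code: the class and function names, the EMA update rule, the cosine-similarity dependency score, and the projection-based gating are all assumptions paraphrasing the mechanism the abstract describes.

```python
import numpy as np

class DynamicFeatureBank:
    """Hypothetical momentum-updated bank of global prototypes
    (one slot per frequent anatomical/linguistic pattern)."""

    def __init__(self, num_prototypes, feature_dim, momentum=0.99, seed=0):
        rng = np.random.default_rng(seed)
        self.bank = rng.normal(size=(num_prototypes, feature_dim))
        self.momentum = momentum

    def update(self, slot, feature):
        # Exponential moving average: prototypes drift slowly toward
        # the patterns that recur across the dataset.
        self.bank[slot] = (self.momentum * self.bank[slot]
                           + (1.0 - self.momentum) * feature)

def causal_trim(feature, bank, temperature=1.0):
    """Softly suppress the part of an instance feature that is
    highly correlated with a global prototype (assumed gating rule)."""
    # Cosine similarity between the instance feature and each prototype.
    p = bank / (np.linalg.norm(bank, axis=1, keepdims=True) + 1e-8)
    f_norm = feature / (np.linalg.norm(feature) + 1e-8)
    sims = p @ f_norm
    k = int(np.argmax(sims))
    # Differentiable dependency score in (0, 1): a sigmoid of the
    # strongest prototype match, acting as a soft suppression gate.
    dependency = 1.0 / (1.0 + np.exp(-sims[k] / temperature))
    # Remove a fraction of the component shared with that prototype,
    # keeping the instance-specific residual intact.
    projection = (feature @ p[k]) * p[k]
    trimmed = feature - dependency * projection
    return trimmed, dependency
```

In this sketch, a feature aligned with a frequent prototype has its shared component scaled down in proportion to the dependency score, while orthogonal (instance-specific) components pass through unchanged; because every step is smooth, the gate could in principle be trained end-to-end as the abstract describes.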