🤖 AI Summary
Medical diagnostic models often suffer from poor generalizability due to class-feature bias and class imbalance—i.e., overreliance on spurious features strongly correlated with only certain classes. To address this, we propose a class-unbiased training framework that jointly tackles both issues: (1) an inter-class inequality loss explicitly enforces balanced feature–class associations by penalizing disparities in feature attribution across classes; and (2) class-weighted distributionally robust optimization (DRO) improves worst-case performance under long-tailed class distributions. Our method is validated on hybrid datasets combining synthetic and real-world multicenter medical data. Experiments across multiple diagnostic tasks demonstrate consistent gains: average accuracy improvements of +3.2–5.8%, and a 37% reduction in performance standard deviation across classes—indicating markedly enhanced cross-class stability. To our knowledge, this is the first work to simultaneously decouple and model class-feature bias and class imbalance within a unified framework.
📝 Abstract
Medical diagnosis might fail due to bias. In this work, we identified class-feature bias, which refers to models' potential reliance on features that are strongly correlated with only a subset of classes, leading to biased performance and poor generalization on other classes. We aim to train a class-unbiased model (Cls-unbias) that mitigates both class imbalance and class-feature bias simultaneously. Specifically, we propose a class-wise inequality loss which promotes equal contributions of classification loss from positive-class and negative-class samples. We propose to optimize a class-wise group distributionally robust optimization objective-a class-weighted training objective that upweights underperforming classes-to enhance the effectiveness of the inequality loss under class imbalance. Through synthetic and real-world datasets, we empirically demonstrate that class-feature bias can negatively impact model performance. Our proposed method effectively mitigates both class-feature bias and class imbalance, thereby improving the model's generalization ability.