Bayesian Meta-Learning for Improving Generalizability of Health Prediction Models With Similar Causal Mechanisms

📅 2023-10-19
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
Medical prediction models often suffer from poor generalizability and negative transfer across heterogeneous populations. Method: We propose the first Bayesian meta-learning framework integrating causal mechanism similarity modeling. It employs causal representation learning to identify stable, shared mechanisms across tasks and dynamically constructs multi-task Bayesian priors; during meta-training, it selects beneficial tasks to mitigate negative transfer, while during meta-testing, it enables personalized predictions for new patients and cross-cohort adaptation (e.g., UK Biobank → FinnGen). Contribution/Results: Compared to standard meta-learning, non-causal similarity-based approaches, and single-task baselines, our framework significantly improves disease prediction accuracy and robustness across diverse cohorts. It establishes a novel paradigm for generalizable, EHR-driven medical AI grounded in causal invariance and Bayesian uncertainty quantification.
📝 Abstract
Machine learning strategies like multi-task learning, meta-learning, and transfer learning enable efficient adaptation of machine learning models to specific applications in healthcare, such as prediction of various diseases, by leveraging generalizable knowledge across large datasets and multiple domains. In particular, Bayesian meta-learning methods pool data across related prediction tasks to learn prior distributions for model parameters, which are then used to derive models for specific tasks. However, inter- and intra-task variability due to disease heterogeneity and other patient-level differences pose challenges of negative transfer during shared learning and poor generalizability to new patients. We introduce a novel Bayesian meta-learning approach that aims to address this in two key settings: (1) predictions for new patients (same population as the training set) and (2) adapting to new patient populations. Our main contribution is in modeling similarity between causal mechanisms of the tasks, for (1) mitigating negative transfer during training and (2) fine-tuning that pools information from tasks that are expected to aid generalizability. We propose an algorithm for implementing this approach for Bayesian deep learning, and apply it to a case study for stroke prediction tasks using electronic health record data. Experiments for the UK Biobank dataset as the training population demonstrated significant generalizability improvements compared to standard meta-learning, non-causal task similarity measures, and local baselines (separate models for each task). This was assessed for a variety of tasks that considered both new patients from the training population (UK Biobank) and a new population (FinnGen).
Problem

Research questions and friction points this paper is trying to address.

Machine Learning
Medical Diagnosis
Personalized Prediction
Innovation

Methods, ideas, or system contributions that make the work stand out.

Bayesian Meta-Learning
Causal Similarity
Stroke Prediction
🔎 Similar Papers
No similar papers found.
S
Sophie Wharrie
Department of Computer Science, Aalto University, Espoo, Finland
L
Lisa Eick
Institute for Molecular Medicine Finland, Helsinki Institute of Life Science, University of Helsinki, Helsinki, Finland
L
Lotta Makinen
Department of Computer Science, Aalto University, Espoo, Finland
A
Andrea Ganna
Institute for Molecular Medicine Finland, Helsinki Institute of Life Science, University of Helsinki, Helsinki, Finland; Massachusetts General Hospital and Broad Institute of MIT and Harvard, Cambridge, MA, USA
Samuel Kaski
Samuel Kaski
Director, ELLIS Institute Finland; Professor, Aalto University and University of Manchester
Probabilistic machine learningAI4ScienceCollaborative AI
F
FinnGen
Department of Computer Science, Aalto University, Espoo, Finland