Generative Modeling of Medical Images with Counterfactuals and Causal Principles

Contact person: Emre Kavak, Christian Wachinger

Overview
What if we could ask "what if" questions in medical imaging? Counterfactual reasoning, that is, exploring alternative scenarios such as "What would this brain MRI look like if the patient were older or had a different condition?", offers a powerful way to improve the understanding, fairness, and robustness of medical AI systems. This project develops generative models for medical imaging, particularly brain MRIs, that incorporate causal principles to produce biologically plausible counterfactuals for diverse and equitable medical applications.

Motivation
Medical imaging often involves complex interactions between patient demographics (age, sex, etc.), pathology, and imaging artifacts. Models that fail to account for these causal factors risk being biased or inequitable. By leveraging counterfactuals, we can disentangle these factors, generate meaningful alternative scenarios, and address data imbalances, enabling AI systems to perform better across diverse patient populations.

Objectives

  • Counterfactuals in Generative Modeling: Build deep generative models (e.g., GANs, VAEs, diffusion models) that synthesize realistic medical images conditioned on causal interventions.
  • Data Augmentation for Fairness: Use counterfactuals to generate balanced datasets that mitigate bias and improve equity in AI models.
  • Causal Insights for Medical AI: Explore how counterfactual reasoning helps uncover and correct spurious correlations in medical imaging tasks.
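To make the first objective concrete, the standard recipe for computing counterfactuals is Pearl's three-step procedure: abduction (infer the exogenous noise consistent with the observation), action (intervene on the variable of interest), and prediction (re-run the mechanism with the inferred noise). The sketch below illustrates this on a deliberately tiny, hypothetical structural causal model; the linear "brain volume" equation and all numbers are invented for illustration. In the project itself, the structural equation would be a deep generative model and the noise a latent image representation.

```python
# Illustrative only: Pearl's abduction-action-prediction procedure on a
# toy structural causal model where a scalar "brain volume" depends on age.
# The structural equation and constants are made up for demonstration.

def volume_mechanism(age, noise):
    """Toy structural equation: volume declines linearly with age."""
    return 1500.0 - 5.0 * age + noise

def counterfactual_volume(observed_age, observed_volume, new_age):
    # 1) Abduction: recover the exogenous noise that explains the observation.
    noise = observed_volume - (1500.0 - 5.0 * observed_age)
    # 2) Action: intervene on age, do(age = new_age).
    # 3) Prediction: re-run the mechanism with the inferred noise.
    return volume_mechanism(new_age, noise)

# A 60-year-old with observed volume 1180: "what if they were 70?"
print(counterfactual_volume(60, 1180.0, 70))  # 1130.0
```

Note the consistency property: intervening with the factually observed age returns the observed volume exactly, which is one sanity check counterfactual image models are also evaluated against.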

Example Applications

  • Brain MRI Counterfactuals: Generate MRI scans simulating different ages, sexes, or disease conditions to study causal effects on model predictions.
  • Fair AI Systems: Train models on augmented datasets with counterfactual samples to improve fairness across demographic subgroups, such as age or race.
  • Pathology-Specific Analysis: Use counterfactuals to isolate the impact of specific conditions, enabling interpretable and robust diagnostic models.

Student Requirements
Ideal candidates will have:

  • Knowledge of deep learning, particularly generative modeling (GANs, VAEs, Diffusion Models, etc.).
  • An interest in causal inference, counterfactual reasoning, and fairness in AI.
  • A strong willingness to engage with medical imaging data and its unique challenges.
  • Very strong Python programming skills (incl. PyTorch and the broader scientific stack).

Expected Outcomes

  • A cutting-edge framework for generating realistic, causally informed counterfactual medical images.
  • Demonstrable improvements in the fairness and robustness of medical imaging AI systems through counterfactual data augmentation.
  • A deeper understanding of causal relationships in medical imaging that drive bias and inequities.

Related Work 

  • Latent Diffusion Model for Brain MRI generation: https://arxiv.org/pdf/2209.07162
  • Counterfactuals with GANs: https://arxiv.org/pdf/2111.14923
  • Diffusion Causal Models for Counterfactual Estimation: https://arxiv.org/pdf/2202.10166
  • Counterfactual Image Editing: https://arxiv.org/pdf/2403.09683
  • High Fidelity Image Counterfactuals with Probabilistic Causal Models: https://proceedings.mlr.press/v202/de-sousa-ribeiro23a/de-sousa-ribeiro23a.pdf