🤖 AI Summary
Traditional flow matching models individual point samples, which makes it inadequate when each sample is itself a distribution, e.g., a 3D shape or a cellular microenvironment in single-cell spatial transcriptomics. This work lifts flow matching to the Wasserstein space of probability distributions, introducing the first generative framework designed for families of distributions. It uses Wasserstein geodesics as conditional flow paths, so that analytic distributions (e.g., Gaussians) and empirical ones (e.g., point clouds) are handled within one framework. We prove that these geodesics are valid conditional flows, which makes the resulting flow matching objective well founded, and we combine entropy-regularized optimal transport estimates with attention mechanisms for efficient generation of high-dimensional distributions. Experiments on 2D/3D shape modeling and spatial transcriptomics data demonstrate effective distribution-level generation, extending generative modeling beyond conventional point data. The code is open source and supports both Gaussian and point-cloud representations of distributions.
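To make the idea of a closed-form conditional path between Gaussians concrete, here is a small NumPy sketch of the standard Bures-Wasserstein geodesic between N(m0, S0) and N(m1, S1), together with the affine velocity field that generates it. The function names (`psd_sqrt`, `bw_ot_matrix`, `gaussian_geodesic`, `gaussian_velocity`) are hypothetical and this is not the code from the authors' repository; in particular, how WFM parameterizes Gaussian velocities for its network to regress is left as an open assumption here.

```python
import numpy as np

def psd_sqrt(M, inverse=False):
    """Square root (or inverse square root) of a symmetric positive-definite matrix via eigh."""
    w, V = np.linalg.eigh((M + M.T) / 2)
    w = np.clip(w, 1e-12, None)  # guard against tiny negative eigenvalues from round-off
    d = 1.0 / np.sqrt(w) if inverse else np.sqrt(w)
    return (V * d) @ V.T  # V diag(d) V^T

def bw_ot_matrix(S0, S1):
    """Symmetric A such that T(x) = m1 + A (x - m0) is the OT map from N(m0, S0) to N(m1, S1)."""
    r = psd_sqrt(S0)
    r_inv = psd_sqrt(S0, inverse=True)
    A = r_inv @ psd_sqrt(r @ S1 @ r) @ r_inv
    return (A + A.T) / 2  # symmetrize to remove round-off asymmetry

def gaussian_geodesic(m0, S0, m1, S1, t):
    """Mean and covariance at time t on the Wasserstein geodesic between two Gaussians."""
    A = bw_ot_matrix(S0, S1)
    B = (1 - t) * np.eye(len(m0)) + t * A
    return (1 - t) * m0 + t * m1, B @ S0 @ B

def gaussian_velocity(m0, S0, m1, S1, t):
    """Returns (b, L) with v_t(x) = b + L (x - m_t), the velocity field driving the geodesic."""
    A = bw_ot_matrix(S0, S1)
    I = np.eye(len(m0))
    B = (1 - t) * I + t * A
    L = (A - I) @ np.linalg.inv(B)
    return m1 - m0, L

# Example: midpoint of the path between two 2D Gaussians
m0, m1 = np.zeros(2), np.array([3.0, -1.0])
S0 = np.array([[1.0, 0.3], [0.3, 0.5]])
S1 = np.array([[2.0, -0.4], [-0.4, 1.0]])
m_half, S_half = gaussian_geodesic(m0, S0, m1, S1, 0.5)
b_half, L_half = gaussian_velocity(m0, S0, m1, S1, 0.5)
```

Along this path the mean moves linearly while the covariance follows S_t = B_t S0 B_t with B_t = (1 - t)I + tA; the matrix `L` is symmetric because B_t is a polynomial in A, and it satisfies dS_t/dt = L S_t + S_t L.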
📝 Abstract
Generative modeling typically concerns transporting a single source distribution to a target distribution via simple probability flows. However, in fields like computer graphics and single-cell genomics, samples themselves can be viewed as distributions, and standard flow matching ignores their inherent geometry. We propose Wasserstein flow matching (WFM), which lifts flow matching (FM) onto families of distributions using the Wasserstein geometry. Notably, WFM is the first algorithm capable of generating distributions in high dimensions, whether represented analytically (as Gaussians) or empirically (as point clouds). Our theoretical analysis establishes that Wasserstein geodesics constitute proper conditional flows over the space of distributions, yielding a valid FM objective. Our algorithm leverages optimal transport theory and the attention mechanism and adapts to different computational regimes: it exploits closed-form optimal transport paths for Gaussian families and uses entropic estimates on point clouds for general distributions. WFM successfully generates 2D and 3D shapes as well as high-dimensional cellular microenvironments from spatial transcriptomics data. Code is available at https://github.com/DoronHav/WassersteinFlowMatching .