🤖 AI Summary
This work addresses the computational inefficiency of diffusion model sampling caused by the high cost of evaluating high-precision drift estimators. It introduces the multilevel Monte Carlo (MLMC) framework into diffusion sampling for the first time and proposes the Multilevel Euler–Maruyama (ML-EM) method. By constructing a hierarchy of UNet-based drift approximators of increasing accuracy and computational cost, ML-EM significantly reduces the expense of solving the underlying stochastic differential equation while preserving solution accuracy. Theoretical analysis shows that, in the Harder than Monte Carlo (HTMC) regime, where ε-approximating the drift requires ε^{-γ} compute for some γ > 2, the method achieves a polynomial speedup, with an overall computational cost equivalent to a single evaluation of the high-precision drift estimator. Empirical results on CelebA 64×64 image generation, where the measured exponent is γ≈2.5, demonstrate a practical speedup of up to 4×, confirming both the method's theoretical advantages and its practical efficacy.
📝 Abstract
We introduce the Multilevel Euler-Maruyama (ML-EM) method to compute solutions of SDEs and ODEs using a range of approximators $f^1,\dots,f^k$ to the drift $f$ with increasing accuracy and computational cost, requiring only a few evaluations of the most accurate $f^k$ and many evaluations of the cheaper $f^1,\dots,f^{k-1}$. If the drift lies in the so-called Harder than Monte Carlo (HTMC) regime, i.e. it requires $ε^{-γ}$ compute to be $ε$-approximated for some $γ>2$, then ML-EM $ε$-approximates the solution of the SDE with $ε^{-γ}$ compute, improving over the traditional EM rate of $ε^{-γ-1}$. In other words, it allows us to solve the SDE at the same cost as a single evaluation of the drift. In the context of diffusion models, the different levels $f^{1},\dots,f^{k}$ are obtained by training UNets of increasing size, and ML-EM allows sampling at the cost of a single evaluation of the largest UNet. Our numerical experiments confirm our theory: we obtain up to fourfold speedups for image generation on the CelebA dataset downscaled to 64×64, where we measure $γ\approx2.5$. Since this is a polynomial speedup, we expect even stronger gains in practical applications involving networks that are orders of magnitude larger.
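The multilevel idea behind ML-EM can be illustrated with a classical MLMC telescoping estimator in which each level uses a different drift approximator and adjacent levels are coupled by sharing the same Brownian increments. This is a minimal sketch of the generic multilevel construction, not the paper's exact sampler: the function names (`em_path`, `mlmc_estimate`) and the toy Ornstein–Uhlenbeck drifts used below are our own illustrative choices.

```python
import numpy as np

def em_path(f, x0, T, n_steps, dW):
    """Euler-Maruyama with drift f and prescribed Brownian increments dW."""
    dt = T / n_steps
    x = x0
    for i in range(n_steps):
        x = x + f(x) * dt + dW[i]
    return x

def mlmc_estimate(drifts, samples_per_level, x0, T, n_steps, rng):
    """Telescoping multilevel estimator of E[X_T]:
    E[X^k_T] = E[X^1_T] + sum_{l=2}^{k} E[X^l_T - X^{l-1}_T],
    using many paths with the cheap drift and few coupled correction paths.
    """
    dt = T / n_steps
    # Level 1: many independent paths with the cheapest drift.
    est = np.mean([
        em_path(drifts[0], x0, T, n_steps, rng.normal(0.0, np.sqrt(dt), n_steps))
        for _ in range(samples_per_level[0])
    ])
    # Higher levels: fewer paths; each pair shares dW, so the
    # difference has small variance and the correction is cheap.
    for l in range(1, len(drifts)):
        diffs = []
        for _ in range(samples_per_level[l]):
            dW = rng.normal(0.0, np.sqrt(dt), n_steps)
            diffs.append(em_path(drifts[l], x0, T, n_steps, dW)
                         - em_path(drifts[l - 1], x0, T, n_steps, dW))
        est += np.mean(diffs)
    return est

# Toy check: crude drift -1.05x corrected toward the exact OU drift -x,
# whose mean at T=1 starting from x0=1 is exp(-1) ≈ 0.368.
rng = np.random.default_rng(0)
drifts = [lambda x: -1.05 * x, lambda x: -x]
print(mlmc_estimate(drifts, [2000, 200], 1.0, 1.0, 50, rng))
```

The coupling via shared `dW` is what makes the variance of each correction term shrink as the levels get closer, so only a few evaluations of the most expensive drift are needed; in the diffusion-model setting the drifts would be UNets of increasing size rather than closed-form functions.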