🤖 AI Summary
Problem: In Gaussian variational inference on the Bures–Wasserstein manifold, the gradient of the potential energy involves intractable expectations, and existing single-sample Monte Carlo estimators suffer from high variance and slow convergence.
Method: We introduce control variates into this framework, proposing a stochastic gradient estimator whose variance is provably smaller than that of the Monte Carlo estimator in scenarios of interest. The method builds on the forward–backward Euler discretization of the Wasserstein gradient flow, in which the backward step has a closed-form solution and only the forward step requires estimation.
Contribution/Results: We prove that the variance reduction improves the optimization bounds of the existing analysis. Experiments show order-of-magnitude improvements over previous Bures–Wasserstein variational inference methods.
📝 Abstract
Optimization in the Bures-Wasserstein space has been gaining popularity in the machine learning community since it draws connections between variational inference and Wasserstein gradient flows. The variational inference objective, the Kullback-Leibler divergence, can be written as the sum of the negative entropy and the potential energy, making forward-backward Euler the method of choice. Notably, the backward step admits a closed-form solution in this case, which makes the scheme practical. However, the forward step is not exact since the Bures-Wasserstein gradient of the potential energy involves "intractable" expectations. Recent approaches use the Monte Carlo method (in practice, a single-sample estimator) to approximate these terms, resulting in high variance and poor performance. We propose a novel variance-reduced estimator based on the principle of control variates. We theoretically show that this estimator has a smaller variance than the Monte Carlo estimator in scenarios of interest. We also prove that variance reduction improves the optimization bounds of the current analysis. We demonstrate that the proposed estimator achieves order-of-magnitude improvements over previous Bures-Wasserstein methods.
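
To make the control-variate principle concrete, here is a minimal one-dimensional sketch in Python. It is not the estimator proposed in the paper, whose exact form the abstract does not give. The potential V(x) = log cosh(x), the Gaussian parameters, and the helper names `grad_V`, `hess_V`, and `single_sample_estimates` are all assumptions chosen for illustration. The expectation E[V'(X)] under X ~ N(m, s^2) is estimated either by the plain single-sample value V'(X) or by V'(X) - beta (X - m). The correction term has zero mean, so the estimator stays unbiased, and the coefficient beta = V''(m) is a first-order Taylor choice rather than an optimal one.

```python
import numpy as np


def grad_V(x):
    # Gradient of the hypothetical potential V(x) = log cosh(x).
    return np.tanh(x)


def hess_V(x):
    # Second derivative of V; used only to pick the control-variate coefficient.
    return 1.0 - np.tanh(x) ** 2


def single_sample_estimates(m, s, n_rep, rng):
    """Draw n_rep independent single-sample estimates of E[grad_V(X)], X ~ N(m, s^2).

    Returns the plain Monte Carlo estimates and the control-variate estimates,
    both computed from the same draws so their variances can be compared.
    """
    x = m + s * rng.standard_normal(n_rep)
    plain = grad_V(x)
    # Control variate: (x - m) has known mean zero under N(m, s^2).
    beta = hess_V(m)  # assumed coefficient (first-order Taylor heuristic)
    with_cv = plain - beta * (x - m)
    return plain, with_cv


rng = np.random.default_rng(0)
plain, with_cv = single_sample_estimates(m=0.5, s=0.3, n_rep=200_000, rng=rng)

print("mean (plain, control variate):", plain.mean(), with_cv.mean())
print("variance (plain, control variate):", plain.var(), with_cv.var())
```

Both estimators target the same expectation, so their empirical means should agree up to sampling noise. The variance of the second falls because the linear term absorbs most of the fluctuation of V'(X) around the mean. The benefit shrinks when V' is strongly nonlinear over the spread of the Gaussian, which is consistent with the abstract's qualifier that the variance reduction holds "in scenarios of interest".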