🤖 AI Summary
This work proposes a stochastic gradient variational inference method based on Price's gradient estimator, applicable when only the unnormalized log-density of the target distribution is known. The approach unifies convergence analyses in the Bures–Wasserstein space and the parameter space, and establishes for the first time that black-box variational inference using Price's gradient achieves the same optimal iteration complexity as Wasserstein variational inference, closing a key theoretical gap between the two classes of methods. The core insight is that Price's gradient effectively exploits Hessian information from the target log-density, which, under a Gaussian variational family, yields provably optimal convergence rates. Experiments confirm that this use of second-order information is the main source of the observed performance gains.
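For a Gaussian variational family $q = \mathcal{N}(\mu, \Sigma)$, the identity the summary refers to can be stated in its standard textbook form (this is the usual statement of Bonnet's and Price's theorems, not a formula quoted from the paper):

$$
\nabla_{\mu}\,\mathbb{E}_{x\sim\mathcal{N}(\mu,\Sigma)}\!\left[f(x)\right] = \mathbb{E}\!\left[\nabla f(x)\right],
\qquad
\nabla_{\Sigma}\,\mathbb{E}_{x\sim\mathcal{N}(\mu,\Sigma)}\!\left[f(x)\right] = \tfrac{1}{2}\,\mathbb{E}\!\left[\nabla^{2} f(x)\right].
$$

The second identity is the one associated with Price's theorem: the gradient with respect to the covariance is half the expected Hessian of the target log-density, which is how second-order information enters the estimator.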
📝 Abstract
For approximating a target distribution given only its unnormalized log-density, stochastic gradient-based variational inference (VI) algorithms are a popular approach. For example, Wasserstein VI (WVI) and black-box VI (BBVI) perform gradient descent in measure space (Bures–Wasserstein space) and parameter space, respectively. Previously, for the Gaussian variational family, convergence guarantees for WVI have shown superiority over existing results for BBVI with the reparametrization gradient, suggesting the measure space approach might provide some unique benefits. In this work, however, we close this gap by obtaining identical state-of-the-art iteration complexity guarantees for both. In particular, we identify that WVI's superiority stems from the specific gradient estimator it uses, which BBVI can also leverage with minor modifications. The estimator in question is usually associated with Price's theorem and utilizes second-order information (Hessians) of the target log-density. We will refer to this as Price's gradient. On the flip side, WVI can be made more widely applicable by using the reparametrization gradient, which requires only gradients of the log-density. We empirically demonstrate that the use of Price's gradient is the major source of performance improvement.
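As a toy illustration (not from the paper), the two estimators contrasted in the abstract can be compared on a quadratic target log-density $f(x) = -\tfrac{1}{2}x^{\top}Ax$, where both have simple closed forms. Price's gradient averages sampled Hessians of $f$ (exact here, since the Hessian is the constant $-A$), while the reparametrization gradient needs only first-order information:

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.array([[2.0, 0.3],
              [0.3, 1.0]])   # toy target: f(x) = -0.5 x^T A x, so grad f = -Ax, Hessian = -A
mu = np.array([0.5, -0.2])
L = np.array([[1.0, 0.0],
              [0.2, 0.8]])   # Cholesky factor of the variational covariance Sigma = L L^T

n = 200_000
eps = rng.standard_normal((n, 2))
xs = mu + eps @ L.T          # reparametrized samples x = mu + L @ eps
grads = -xs @ A              # per-sample gradients of f (A is symmetric)

# Price's gradient: grad_Sigma E[f] = 0.5 * E[Hessian f(x)];
# exact without sampling here because the Hessian is constant.
price_grad_Sigma = 0.5 * (-A)

# Reparametrization gradient: grad_L E[f] = E[grad f(mu + L eps) eps^T],
# estimated by Monte Carlo from the same samples.
reparam_grad_L = grads.T @ eps / n

# Closed forms for this target: grad_Sigma E[f] = -0.5 A and grad_L E[f] = -A L.
print(np.allclose(price_grad_Sigma, -0.5 * A))
print(np.allclose(reparam_grad_L, -A @ L, atol=0.02))
```

For a general (non-quadratic) target, `price_grad_Sigma` would instead be a Monte Carlo average of sampled Hessians, which is where the second-order information highlighted in the abstract comes in.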