🤖 AI Summary
This work addresses the high variance and poor scalability of stochastic trace estimation, particularly Hutchinson's estimator, as used in divergence-based likelihood training of diffusion models. These limitations cause training instability and degraded sample quality, especially for ill-conditioned Jacobian matrices with large condition numbers. To this end, we integrate Hutch++, a theoretically grounded trace estimator with provably faster variance convergence than standard estimators, into diffusion model training, which the authors present as the first such application to generative modeling optimization. We further introduce an efficient implementation that avoids repeated QR decompositions by exploiting the inherent low-rank structure of high-dimensional Jacobians. Empirically, on image generation and conditional time-series forecasting tasks, Hutch++ significantly reduces training variance, improves sample fidelity, and stabilizes training. These results demonstrate its dual advantages in both optimal transport consistency and computational efficiency.
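Hutch++ (Meyer et al., 2021) improves on Hutchinson's estimator by spending part of the matrix-vector budget on a low-rank sketch, computing the trace exactly on the captured subspace, and running Hutchinson only on the deflated remainder. A minimal NumPy sketch of both estimators, assuming a symmetric operator accessed only through matrix-vector products (function names and the 1/3-budget split are illustrative):

```python
import numpy as np

def hutchinson(matvec, n, num_queries, rng):
    """Plain Hutchinson: average g^T A g over Rademacher probes g."""
    est = 0.0
    for _ in range(num_queries):
        g = rng.choice([-1.0, 1.0], size=n)
        est += g @ matvec(g)
    return est / num_queries

def hutchpp(matvec, n, num_queries, rng):
    """Hutch++: exact trace on a sketched subspace plus Hutchinson
    on the deflated operator (I - QQ^T) A (I - QQ^T)."""
    m = num_queries // 3  # split the matvec budget roughly in thirds
    S = rng.choice([-1.0, 1.0], size=(n, m))
    Y = np.column_stack([matvec(S[:, j]) for j in range(m)])
    Q, _ = np.linalg.qr(Y)  # orthonormal basis for range(A S)
    # Exact contribution: tr(Q^T A Q)
    AQ = np.column_stack([matvec(Q[:, j]) for j in range(Q.shape[1])])
    est = np.trace(Q.T @ AQ)
    # Hutchinson on the deflated remainder, with probes projected off range(Q)
    G = rng.choice([-1.0, 1.0], size=(n, m))
    G = G - Q @ (Q.T @ G)
    AG = np.column_stack([matvec(G[:, j]) for j in range(m)])
    est += np.trace(G.T @ AG) / m
    return est
```

When the spectrum decays quickly (the low-rank Jacobian regime the summary describes), the sketch absorbs most of the trace mass and the residual Hutchinson term has far lower variance than the plain estimator at the same query budget.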
📝 Abstract
Hutchinson estimators are widely employed in training divergence-based likelihoods for diffusion models to ensure optimal transport (OT) properties. However, these estimators often suffer from high variance and scale poorly. To address these challenges, we investigate Hutch++, an optimal stochastic trace estimator for generative models, designed to minimize training variance while maintaining transport optimality. Hutch++ is particularly effective for ill-conditioned matrices with large condition numbers, which commonly arise when high-dimensional data exhibits low-dimensional structure. To mitigate the need for frequent and costly QR decompositions, we propose practical schemes that balance refresh frequency against accuracy, backed by theoretical guarantees. Our analysis demonstrates that Hutch++ yields higher-quality generations. Furthermore, the method achieves effective variance reduction across applications including simulations, conditional time-series forecasting, and image generation.
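The abstract's schemes for reducing QR cost are not detailed here. As one illustrative possibility (an assumption for exposition, not the paper's actual scheme), the orthonormal sketch basis could be cached and refreshed only every few trace estimates, trading QR cost against sketch staleness:

```python
import numpy as np

class AmortizedSketch:
    """Hypothetical sketch cache: reuse the basis Q for `refresh_every`
    trace estimates before recomputing it with a fresh QR."""
    def __init__(self, n, m, refresh_every, rng):
        self.n, self.m, self.refresh_every = n, m, refresh_every
        self.rng = rng
        self.Q = None
        self.calls = 0

    def trace(self, matvec):
        # Recompute the sketch basis only on a schedule, not every call.
        if self.Q is None or self.calls % self.refresh_every == 0:
            S = self.rng.choice([-1.0, 1.0], size=(self.n, self.m))
            Y = np.column_stack([matvec(S[:, j]) for j in range(self.m)])
            self.Q, _ = np.linalg.qr(Y)
        self.calls += 1
        Q = self.Q
        # Exact trace on the (possibly stale) sketched subspace.
        AQ = np.column_stack([matvec(Q[:, j]) for j in range(Q.shape[1])])
        est = np.trace(Q.T @ AQ)
        # Hutchinson on the deflated remainder.
        G = self.rng.choice([-1.0, 1.0], size=(self.n, self.m))
        G = G - Q @ (Q.T @ G)
        AG = np.column_stack([matvec(G[:, j]) for j in range(self.m)])
        est += np.trace(G.T @ AG) / self.m
        return est
```

The estimator stays unbiased for any fixed Q; staleness only costs variance when the Jacobian's dominant subspace drifts between refreshes, which is the frequency/accuracy trade-off the abstract alludes to.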