🤖 AI Summary
This work addresses the limited theoretical understanding of why Transformers can learn optimal denoisers in diffusion models, and in particular how gradient-based training reaches the Bayes-optimal solution despite a non-convex loss. To the authors' knowledge, it gives the first global convergence analysis for Transformers trained on the population denoising diffusion probabilistic model (DDPM) objective, with data drawn from a multi-token Gaussian mixture distribution. By analyzing this objective together with the self-attention architecture, the paper shows that the trained self-attention module implements a mean denoising mechanism, which lets the model approximate the oracle minimum mean squared error (MMSE) estimator of the injected noise. The analysis quantifies the number of tokens per data point and the number of training iterations needed to reach a prescribed score-matching error. Numerical experiments agree with the theoretical predictions and with MMSE estimation.
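For readers who want "Bayes-optimal risk" and "score-matching error" spelled out, the block below sketches the standard DDPM noise-prediction setup at a single fixed step $t$. The notation ($\bar\alpha_t$, $\epsilon_\theta$, $s_\theta$, $\mathcal{L}^\star$) is assumed here and not taken from the paper, whose population objective may also average or weight over steps; the identities themselves (conditional mean as the MSE minimizer, Tweedie's formula) are standard.

```latex
% Forward process at a fixed step t (standard DDPM notation, assumed)
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I_d)

% Denoising (noise-prediction) objective and its Bayes-optimal minimizer
\mathcal{L}(\theta) = \mathbb{E}_{x_0,\epsilon}\bigl\|\epsilon_\theta(x_t,t) - \epsilon\bigr\|_2^2,
\qquad
\epsilon^\star(x_t,t) = \mathbb{E}[\epsilon \mid x_t] \quad \text{(MMSE estimator)}

% Excess risk over the Bayes-optimal risk \mathcal{L}^\star = \mathcal{L}(\epsilon^\star)
\mathcal{L}(\theta) - \mathcal{L}^\star
= \mathbb{E}\bigl\|\epsilon_\theta(x_t,t) - \epsilon^\star(x_t,t)\bigr\|_2^2

% Tweedie's formula links the MMSE noise estimate to the score
\nabla_{x}\log p_t(x_t) = -\frac{\mathbb{E}[\epsilon \mid x_t]}{\sqrt{1-\bar\alpha_t}}

% Hence, with s_\theta = -\epsilon_\theta / \sqrt{1-\bar\alpha_t}:
\mathbb{E}\bigl\|s_\theta(x_t,t) - \nabla_{x}\log p_t(x_t)\bigr\|_2^2
= \frac{\mathcal{L}(\theta) - \mathcal{L}^\star}{1-\bar\alpha_t}
```

Under these assumptions, driving the denoising loss to its Bayes-optimal value is equivalent to driving the score-matching error at step $t$ to zero, which is why the two quantities appear together.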
📝 Abstract
Transformer-based diffusion models have demonstrated remarkable performance at generating high-quality samples. However, our theoretical understanding of the reasons for this success remains limited. For instance, existing models are typically trained by minimizing a denoising objective, which is equivalent to fitting the score function of the training data. Yet we do not know why transformer-based models can match the score function for denoising, or why gradient-based methods converge to the optimal denoising model despite the non-convex loss landscape. To the best of our knowledge, this paper provides the first convergence analysis for training transformer-based diffusion models. More specifically, we consider the population Denoising Diffusion Probabilistic Model (DDPM) objective for denoising data that follow a multi-token Gaussian mixture distribution. We theoretically quantify the number of tokens per data point and the number of training iterations required for global convergence to the Bayes-optimal risk of the denoising objective, thereby achieving a desired score-matching error. A deeper investigation reveals that the self-attention module of the trained transformer implements a mean denoising mechanism that enables the trained model to approximate the oracle Minimum Mean Squared Error (MMSE) estimator of the injected noise in the diffusion steps. Numerical experiments validate these findings.
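To make "oracle MMSE estimator of the injected noise" concrete, here is a minimal NumPy sketch for a deliberately simplified case that is an assumption, not the paper's model: a single token, $K$ equally weighted isotropic components $\mathcal{N}(\mu_k, \sigma_0^2 I_d)$, and one fixed diffusion step with signal level `alpha_bar`. In this case $\mathbb{E}[\epsilon \mid x_t]$ has a closed form in which the posterior over components is a softmax of dot products between $x_t$ and the scaled means, so it reads like one attention head whose keys and values are the scaled means. The functions `mmse_noise_estimate` and `simulate` and all parameter values are hypothetical; this does not reproduce the authors' multi-token construction or their trained transformer.

```python
import numpy as np


def mmse_noise_estimate(x_t, means, alpha_bar, sigma0):
    """Oracle E[eps | x_t] when x_0 ~ (1/K) * sum_k N(mu_k, sigma0^2 I) (hypothetical setup).

    x_t:   (n, d) noisy samples at one diffusion step
    means: (K, d) component means mu_k
    """
    # Given component k, x_t ~ N(sqrt(alpha_bar) * mu_k, s2 * I).
    s2 = alpha_bar * sigma0**2 + (1.0 - alpha_bar)
    centers = np.sqrt(alpha_bar) * means  # (K, d): act as attention keys and values

    # Log-posterior over components up to a per-row constant:
    # <x_t, c_k> / s2 - ||c_k||^2 / (2 s2); the shared -||x_t||^2 / (2 s2) cancels.
    logits = (x_t @ centers.T) / s2 - 0.5 * np.sum(centers**2, axis=1) / s2
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)  # softmax rows = P(k | x_t)

    # E[eps | x_t, k] = sqrt(1 - alpha_bar) / s2 * (x_t - c_k); average over the posterior.
    return np.sqrt(1.0 - alpha_bar) / s2 * (x_t - weights @ centers)


def simulate(n=200_000, d=8, K=4, alpha_bar=0.5, sigma0=0.3, seed=0):
    """Draw (eps, eps_hat) pairs from x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) eps."""
    rng = np.random.default_rng(seed)
    means = 2.0 * rng.normal(size=(K, d))
    labels = rng.integers(K, size=n)
    x0 = means[labels] + sigma0 * rng.normal(size=(n, d))
    eps = rng.normal(size=(n, d))
    x_t = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * eps
    return eps, mmse_noise_estimate(x_t, means, alpha_bar, sigma0)


if __name__ == "__main__":
    eps, eps_hat = simulate()
    mse = lambda pred: np.mean(np.sum((pred - eps) ** 2, axis=1))
    print("oracle MSE:   ", mse(eps_hat))
    print("zero-pred MSE:", mse(np.zeros_like(eps)))
    print("0.9 * oracle: ", mse(0.9 * eps_hat))
```

Because $\mathbb{E}\|\epsilon\|^2 = d$, the zero predictor's MSE should come out close to `d`, while the oracle's should be much smaller, and rescaling `eps_hat` in either direction should only increase the error, as expected of a conditional mean. The bias term `-||c_k||^2 / (2 s2)` is the same for every component when all means have equal norm, in which case the weights reduce to a pure dot-product softmax.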