🤖 AI Summary
This work addresses a failure of joint training: when variational autoencoders (VAEs) and latent diffusion models (LDMs) are trained end-to-end with the standard diffusion loss, training is ineffective and final performance can even degrade. We propose REPA-E, a training recipe that unlocks fully differentiable, end-to-end co-optimization of the VAE and the latent diffusion Transformer through the representation-alignment (REPA) loss rather than the diffusion loss. End-to-end tuning with REPA-E also improves the VAE itself, yielding a better-structured latent space and better downstream generation. On ImageNet 256×256, the method achieves state-of-the-art FID scores of 1.26 with classifier-free guidance and 1.83 without. Diffusion model training is accelerated by over 17× relative to the REPA training recipe and by over 45× relative to vanilla LDM training. The core contribution is removing the joint-training bottleneck, so that the VAE and the latent diffusion model can be trained together efficiently and stably.
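REPA-E builds on the REPA loss. As a minimal sketch, assuming the formulation of the original REPA method (per-patch hidden states of the diffusion transformer, after a projection head, are aligned with per-patch features of the clean image from a frozen pretrained visual encoder), the alignment term is the negative mean cosine similarity over patches. The function names and the list-of-vectors input layout are illustrative assumptions; a real implementation would work on batched tensors and learn the projection head.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def cosine_similarity(u: Vector, v: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity between two equal-length vectors (eps guards zero norms)."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / max(norm_u * norm_v, eps)


def repa_loss(projected_hidden: Sequence[Vector],
              encoder_features: Sequence[Vector]) -> float:
    """Negative mean patch-wise cosine similarity (illustrative sketch).

    projected_hidden: one vector per patch, the diffusion transformer's hidden
        state after a (hypothetical) projection head.
    encoder_features: one vector per patch, features of the clean image from a
        frozen pretrained visual encoder (assumed).
    """
    if len(projected_hidden) != len(encoder_features) or not projected_hidden:
        raise ValueError("need the same, non-zero number of patches")
    sims = [cosine_similarity(h, y) for h, y in zip(projected_hidden, encoder_features)]
    return -sum(sims) / len(sims)


# Perfectly aligned patches give the minimum value, -1.0.
print(repa_loss([[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]]))
```

Because cosine similarity ignores vector magnitude, this term rewards the diffusion model's internal features for pointing in the same direction as the pretrained encoder's features, not for matching their scale.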
📝 Abstract
In this paper we tackle a fundamental question: "Can we train latent diffusion models together with the variational auto-encoder (VAE) tokenizer in an end-to-end manner?" Traditional deep-learning wisdom dictates that end-to-end training is often preferable when possible. However, for latent diffusion transformers, it is observed that training both the VAE and the diffusion model end-to-end with the standard diffusion loss is ineffective, and can even degrade final performance. We show that while the diffusion loss is ineffective, end-to-end training can be unlocked through the representation-alignment (REPA) loss, allowing both the VAE and the diffusion model to be jointly tuned during training. Despite its simplicity, the proposed training recipe (REPA-E) performs remarkably well, speeding up diffusion model training by over 17× and 45× relative to the REPA and vanilla training recipes, respectively. Interestingly, we observe that end-to-end tuning with REPA-E also improves the VAE itself, leading to improved latent space structure and downstream generation performance. In terms of final performance, our approach sets a new state of the art, achieving FID of 1.26 and 1.83 with and without classifier-free guidance on ImageNet 256×256. Code is available at https://end2end-diffusion.github.io.
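To make the gradient flow of end-to-end tuning concrete, here is a sketch of how gradients could be routed in one joint step on a hypothetical scalar toy model. It assumes, beyond what the abstract states, that the diffusion loss is blocked from the VAE by a stop-gradient while the REPA loss updates both the VAE and the diffusion model. The toy losses, the parameter names `a`, `w`, `b`, and the default `repa_weight` are illustrative placeholders; finite differences stand in for autograd, and any VAE regularization terms are omitted.

```python
import math
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]
# A loss takes (vae_params, dit_params) and returns a scalar.
LossFn = Callable[[Params, Params], float]


def numerical_grad(f: Callable[[Params], float], params: Params, h: float = 1e-6) -> Params:
    """Central finite-difference gradient of f with respect to each entry of params."""
    grads: Params = {}
    for name in params:
        plus, minus = dict(params), dict(params)
        plus[name] += h
        minus[name] -= h
        grads[name] = (f(plus) - f(minus)) / (2 * h)
    return grads


def end_to_end_grads(vae: Params, dit: Params, diffusion_loss: LossFn,
                     repa_loss: LossFn, repa_weight: float = 0.5) -> Tuple[Params, Params]:
    """Assumed routing: the diffusion model gets both terms, the VAE only the REPA term."""
    dit_grads = numerical_grad(
        lambda d: diffusion_loss(vae, d) + repa_weight * repa_loss(vae, d), dit)
    # Diffusion loss treated as stop-gradient w.r.t. the VAE (assumption).
    vae_grads = numerical_grad(lambda v: repa_weight * repa_loss(v, dit), vae)
    return vae_grads, dit_grads


def sgd_step(params: Params, grads: Params, lr: float) -> Params:
    """Plain gradient-descent update."""
    return {k: params[k] - lr * grads[k] for k in params}


# ---- Hypothetical toy model: scalars stand in for real networks ----
X: List[float] = [1.0, 2.0]        # toy "image"
NOISE: List[float] = [0.5, -0.5]   # fixed noise sample
SIGMA = 0.3                        # fixed noise level
TARGET: List[float] = [2.0, 1.0]   # toy frozen pretrained-encoder feature


def cosine(u: List[float], v: List[float]) -> float:
    dot = sum(p * q for p, q in zip(u, v))
    return dot / (math.sqrt(sum(p * p for p in u)) * math.sqrt(sum(q * q for q in v)))


def noisy_latent(vae: Params) -> Tuple[List[float], List[float]]:
    z = [vae["a"] * xi for xi in X]                      # toy VAE encoder
    z_t = [zi + SIGMA * ni for zi, ni in zip(z, NOISE)]  # forward noising
    return z, z_t


def toy_diffusion_loss(vae: Params, dit: Params) -> float:
    z, z_t = noisy_latent(vae)
    pred = [dit["w"] * zt for zt in z_t]                 # toy denoiser predicts clean z
    return sum((p - zi) ** 2 for p, zi in zip(pred, z)) / len(z)


def toy_repa_loss(vae: Params, dit: Params) -> float:
    _, z_t = noisy_latent(vae)
    hidden = [dit["w"] * zt + dit["b"] for zt in z_t]    # toy hidden state
    return -cosine(hidden, TARGET)                       # negative cosine alignment


vae, dit = {"a": 1.0}, {"w": 1.0, "b": 0.1}
vae_g, dit_g = end_to_end_grads(vae, dit, toy_diffusion_loss, toy_repa_loss)
vae, dit = sgd_step(vae, vae_g, lr=0.1), sgd_step(dit, dit_g, lr=0.1)
```

Under this routing, rescaling the diffusion loss changes the diffusion model's gradients but leaves the VAE's gradient untouched. In a framework with autograd, the same effect would come from detaching the latents before computing the diffusion loss (again an assumed design, not one the abstract specifies).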