🤖 AI Summary
Discrete variational autoencoders (VAEs) suffer from non-differentiable discrete latent variables, necessitating biased or high-variance approximations (such as Gumbel-Softmax or REINFORCE) for gradient estimation, which limits high-fidelity image reconstruction. To address this, we propose a reparameterization-free training framework: a nonparametric encoder serves as a proxy to guide the optimization of a parametric encoder via natural gradient updates; integrated with a Transformer-based encoder and an automatic step-size adaptation mechanism, the framework enables end-to-end training. Our key contribution is the first application of natural gradient optimization, originally developed for policy learning, to discrete VAEs, effectively circumventing the bias-variance trade-off inherent in conventional estimators. On ImageNet 256, our method achieves a 20% improvement in Fréchet Inception Distance (FID) over strong baselines, including vector-quantized VAEs and Gumbel-Softmax VAEs, demonstrating that high-quality image reconstruction is feasible even under highly compact discrete latent representations.
📄 Abstract
Discrete latent bottlenecks in variational autoencoders (VAEs) offer high bit efficiency and can be modeled with autoregressive discrete distributions, enabling parameter-efficient multimodal search with transformers. However, discrete random variables do not admit an exact differentiable parameterization; discrete VAEs therefore typically rely on approximations, such as Gumbel-Softmax reparameterization or straight-through gradient estimation, or employ high-variance gradient-free methods such as REINFORCE, which have had limited success on high-dimensional tasks such as image reconstruction. Inspired by popular techniques in policy search, we propose a training framework for discrete VAEs that leverages the natural gradient of a non-parametric encoder to update the parametric encoder without requiring reparameterization. Our method, combined with automatic step size adaptation and a transformer-based encoder, scales to challenging datasets such as ImageNet and outperforms both approximate reparameterization methods and quantization-based discrete autoencoders in reconstructing high-dimensional data from compact latent spaces, achieving a 20% improvement in FID score on ImageNet 256.
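The proxy-guided, reparameterization-free idea described above can be illustrated with a minimal toy sketch. For a categorical encoder with a softmax parameterization, one natural-gradient-style step toward a fixed nonparametric target distribution amounts to interpolating the logits toward the target's log-probabilities, with the step size controlling the interpolation. All names, the specific update rule, and the choice of target below are illustrative assumptions for intuition only, not the paper's actual algorithm.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array of logits.
    e = np.exp(x - x.max())
    return e / e.sum()

def natural_gradient_step(logits, target_probs, step_size=0.5):
    """Hypothetical natural-gradient-style update for a categorical encoder.

    Under a softmax parameterization, stepping toward a target distribution
    can be written as interpolating the logits toward log(target_probs).
    The step size would be set by an adaptation mechanism in practice;
    here it is a fixed constant for illustration.
    """
    return (1.0 - step_size) * logits + step_size * np.log(target_probs)

# Parametric encoder: uniform categorical over 4 latent codes.
logits = np.zeros(4)

# Nonparametric "proxy" target, e.g. a reconstruction-weighted posterior
# over the codes (values chosen arbitrarily for the demo).
target = np.array([0.7, 0.1, 0.1, 0.1])

# Repeated updates drive the parametric distribution toward the proxy.
for _ in range(20):
    logits = natural_gradient_step(logits, target)

probs = softmax(logits)  # ≈ target after enough steps
```

No sampling or reparameterization of the discrete latent is needed in this update: the gradient signal is carried entirely by the target distribution, which is the appeal of the proxy-based scheme sketched in the abstract.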