Learning Discrete Autoregressive Priors with Wasserstein Gradient Flow

📅 2026-05-07
📈 Citations: 0
✨ Influential: 0

🤖 AI Summary
This work addresses the misalignment between discrete image tokenizers and autoregressive (AR) priors in two-stage generative modeling, where the prior struggles to predict the tokenizer's token sequences accurately during generation. To close this gap, the authors add a distribution-level prior-matching signal to tokenizer training. The original reconstruction objective is kept unchanged, and the token distribution is optimized with a Wasserstein-gradient-flow update so that it better suits AR generation. Guided by a Tripartite Variational Consistency (TVC) analysis, they introduce, for the first time, a backpropagation-free mechanism for aligning the tokenizer with the AR prior during training. Experiments show that the approach reduces AR loss and improves generation quality (measured by FID) on CIFAR-10 and ImageNet while maintaining comparable reconstruction fidelity.
๐Ÿ“ Abstract
Discrete image tokenizers are commonly trained in two stages: first for reconstruction, and then with a prior model fitted to the frozen token sequences. This decoupling leaves the tokenizer unaware of the model that will later generate its tokens. As a result, the learned tokens may preserve image information well but still be difficult for an autoregressive (AR) prior to predict from left to right. We analyze this mismatch using Tripartite Variational Consistency (TVC), which decomposes latent-variable learning into three consistency conditions: conditional-likelihood consistency, prior consistency, and posterior consistency. TVC shows that two-stage training preserves the reconstruction side but leaves prior consistency outside the tokenizer objective: the overall token distribution is fixed before the AR prior participates in training. Motivated by this view, we add a distribution-level prior-matching signal during tokenizer training, while keeping the reconstruction objective unchanged. We optimize this signal with a Wasserstein-gradient-flow update. For hard categorical tokens, the update reduces to a token-level contrast between an auxiliary AR model that tracks the tokenizer's current token distribution and the target AR prior. It requires only forward passes through the two AR models and does not backpropagate through either of them. The resulting tokenizer, wAR-Tok, reduces AR loss and improves generation FID on CIFAR-10 and ImageNet at comparable reconstruction quality.
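The abstract states that, for hard categorical tokens, the update reduces to a token-level contrast between an auxiliary AR model (tracking the tokenizer's current token distribution) and the target AR prior, using forward passes only. The following is a minimal sketch of how such a contrast could be computed, assuming both AR models are run with teacher forcing on the tokenizer's current sequence and return next-token logits of shape (T, K). The functions `log_softmax`, `full_contrast`, and `token_contrast` and their array inputs are hypothetical illustrations, not the authors' wAR-Tok code, and how the signal is fed back into the tokenizer is left unspecified.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def full_contrast(aux_logits: np.ndarray, prior_logits: np.ndarray) -> np.ndarray:
    """Per-position, per-code log-ratio log q_aux(k | z_<t) - log p_prior(k | z_<t).

    aux_logits, prior_logits: hypothetical (T, K) next-token logits obtained
    from forward passes of the auxiliary AR model and the target AR prior.
    Returns a (T, K) array; no gradients are involved, only array arithmetic.
    """
    return log_softmax(aux_logits) - log_softmax(prior_logits)


def token_contrast(aux_logits: np.ndarray, prior_logits: np.ndarray,
                   tokens: np.ndarray) -> np.ndarray:
    """Log-ratio evaluated at the tokens the tokenizer actually emitted.

    tokens: (T,) integer token ids. By the chain rule, the sum of the
    returned (T,) array equals log q_aux(z) - log p_prior(z) for the sequence.
    """
    ratio = full_contrast(aux_logits, prior_logits)
    positions = np.arange(tokens.shape[0])
    return ratio[positions, tokens]


if __name__ == "__main__":
    # Toy example with random logits standing in for the two AR models.
    rng = np.random.default_rng(0)
    T, K = 5, 8
    aux = rng.normal(size=(T, K))
    prior = rng.normal(size=(T, K))
    z = rng.integers(0, K, size=T)
    contrast = token_contrast(aux, prior, z)
    print("per-token contrast:", contrast)
    print("sequence log-ratio:", contrast.sum())
```

Under these assumptions, positions with a large positive contrast mark tokens that the tokenizer currently emits more often than the target prior expects, which is the kind of signal a KL-driven flow would push down. Because the inputs are plain logit arrays, nothing is backpropagated through either AR model.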
Problem

Research questions and friction points this paper is trying to address.

discrete tokenization
autoregressive prior
two-stage training
prior consistency
image generation
Innovation

Methods, ideas, or system contributions that make the work stand out.

Wasserstein gradient flow
discrete tokenizer
autoregressive prior
Tripartite Variational Consistency
prior consistency