🤖 AI Summary
This work addresses the failure of conventional Representation Alignment (REPA) in Just image Transformers (JiT). The failure stems from an information asymmetry between the two objectives: denoising operates in the high-dimensional image space, whereas the semantic alignment target is strongly compressed, so directly regressing onto it becomes a shortcut. As training proceeds, this degrades FID and reduces sample diversity. The study identifies this failure mechanism and proposes PixelREPA, which transforms the alignment target and constrains alignment with a Masked Transformer Adapter: a lightweight, shallow Transformer adapter combined with partial token masking. On ImageNet 256×256, PixelREPA improves JiT-B/16 from an FID of 3.66 to 3.17 and raises the Inception Score (IS) from 275.1 to 284.6, while converging more than twice as fast. The larger variant, PixelREPA-H/16, achieves an FID of 1.81 and an IS of 317.2.
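For reference, the conventional REPA objective can be sketched as the negative mean patch-wise cosine similarity between projected intermediate denoiser tokens and frozen semantic-encoder features; in training it is added to the denoising loss with a weighting coefficient. The sketch below is plain Python for illustration only: the linear `project` head is a hypothetical stand-in for REPA's trainable MLP projector, and the token lists are toy inputs rather than JiT activations or real encoder features.

```python
import math
from typing import List

Vector = List[float]


def cosine(u: Vector, v: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity between two equal-length vectors (eps avoids division by zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + eps)


def project(h: Vector, weight: List[Vector]) -> Vector:
    """Hypothetical linear head standing in for REPA's MLP projector: returns weight @ h."""
    return [sum(w_ij * h_j for w_ij, h_j in zip(row, h)) for row in weight]


def repa_loss(hidden: List[Vector], targets: List[Vector], weight: List[Vector]) -> float:
    """Negative mean patch-wise cosine similarity between projected denoiser
    tokens (one vector per patch) and frozen encoder features for the same patches."""
    if len(hidden) != len(targets):
        raise ValueError("hidden and targets must have the same number of tokens")
    sims = [cosine(project(h, weight), y) for h, y in zip(hidden, targets)]
    return -sum(sims) / len(sims)
```

Because the loss depends only on direction, it reaches its minimum of -1 when every projected token points the same way as its target, regardless of norm; a strongly compressed target is exactly what the abstract argues makes this regression easy to satisfy by shortcut.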
📝 Abstract
Representation Alignment (REPA) has emerged as a simple way to accelerate Diffusion Transformer training in latent space. At the same time, pixel-space diffusion transformers such as Just image Transformers (JiT) have attracted growing attention because they remove the dependency on a pretrained tokenizer and thereby avoid the reconstruction bottleneck of latent diffusion. This paper shows that REPA can fail for JiT. On ImageNet, REPA yields worse FID for JiT as training proceeds and collapses diversity on image subsets that are tightly clustered in the representation space of a pretrained semantic encoder. We trace the failure to an information asymmetry: denoising occurs in the high-dimensional image space, while the semantic target is strongly compressed, which makes direct regression a shortcut objective. We propose PixelREPA, which transforms the alignment target and constrains alignment with a Masked Transformer Adapter that combines a shallow transformer adapter with partial token masking. PixelREPA improves both training convergence and final quality: it reduces FID from 3.66 to 3.17 for JiT-B$/16$ and improves Inception Score (IS) from 275.1 to 284.6 on ImageNet $256 \times 256$, while achieving $> 2\times$ faster convergence. Finally, PixelREPA-H$/16$ achieves FID$=1.81$ and IS$=317.2$. Our code is available at https://github.com/kaist-cvml/PixelREPA.
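The idea of a shallow transformer adapter with partial token masking can be illustrated, under loud assumptions, as follows: a random subset of denoiser tokens is replaced by a mask token, the partially masked sequence passes through a shallow self-attention stack, and every output token is aligned to its encoder feature with a cosine objective. The abstract does not specify the mask ratio, adapter depth, mask token, where masking is applied, or how the alignment target is transformed, so all of these choices, along with the helper names `mask_tokens`, `self_attention`, and `masked_adapter_loss`, are hypothetical, and the target transformation is not modeled at all. A real adapter would also have learned Q/K/V projections, an MLP, and normalization. This is not the authors' implementation; the linked repository is the authoritative source.

```python
import math
import random
from typing import List, Set, Tuple

Vector = List[float]


def cosine(u: Vector, v: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv + eps)


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def self_attention(tokens: List[Vector]) -> List[Vector]:
    """Hypothetical single-head self-attention layer with identity Q/K/V
    projections (untrained) plus a residual connection; shape is preserved."""
    d = len(tokens[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in tokens:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in tokens]
        weights = softmax(scores)
        mixed = [sum(w * v[j] for w, v in zip(weights, tokens)) for j in range(d)]
        out.append([qi + mi for qi, mi in zip(q, mixed)])
    return out


def mask_tokens(tokens: List[Vector], mask_ratio: float,
                rng: random.Random) -> Tuple[List[Vector], Set[int]]:
    """Replace a random subset of token positions with a zero mask token
    (assumed; a learned mask embedding is another option)."""
    n, d = len(tokens), len(tokens[0])
    k = int(round(mask_ratio * n))
    idx = set(rng.sample(range(n), k))
    masked = [[0.0] * d if i in idx else list(t) for i, t in enumerate(tokens)]
    return masked, idx


def masked_adapter_loss(hidden: List[Vector], targets: List[Vector],
                        mask_ratio: float = 0.5, depth: int = 1, seed: int = 0) -> float:
    """Mask tokens, run a shallow attention stack, then align all outputs to targets."""
    if len(hidden) != len(targets):
        raise ValueError("hidden and targets must have the same number of tokens")
    x, _ = mask_tokens(hidden, mask_ratio, random.Random(seed))
    for _ in range(depth):
        x = self_attention(x)
    sims = [cosine(a, y) for a, y in zip(x, targets)]
    return -sum(sims) / len(sims)
```

In this sketch, a masked position can only be aligned through information attended from visible tokens, which is one plausible reading of how masking "constrains" alignment; the paper itself should be consulted for the actual rationale and configuration.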