🤖 AI Summary
This work addresses the challenge of reusing pretrained Softmax attention weights when migrating to linear-complexity attention. The authors convert pretrained Transformers (demonstrated on Stable Diffusion 3.5) directly into linear-complexity models built on Test-Time Training (TTT), aligning the two both architecturally and representationally. They introduce key instance normalization and a lightweight locality-enhancement module so that the linear model reproduces two properties of Softmax attention: invariance to a shared shift of the keys, and locality. This allows the pretrained Softmax attention weights to be inherited directly by the linear TTT layers. With only one hour of fine-tuning on 4×H20 GPUs, the resulting SD3.5-T⁵ model runs inference 1.32× and 1.47× faster at 1K and 2K resolutions, respectively, with generation quality comparable to that of the fine-tuned Softmax model.
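Why normalizing the keys matters can be checked numerically. Softmax attention is unchanged when the same vector c is added to every key, because q·(k_j + c) = q·k_j + q·c and the shared term q·c cancels in the softmax normalization. A kernelized linear attention does not have this property. The sketch below is illustrative only. It uses elu(x)+1 linear attention as a stand-in for the paper's TTT layer. It also assumes that "key instance normalization" means a per-channel mean/std taken over the token axis. Both of these choices are assumptions, not the authors' implementation.

```python
import numpy as np


def softmax_attention(q, k, v):
    """Scaled dot-product Softmax attention; q: (n, d), k: (m, d), v: (m, d_v)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def elu_plus_one(x):
    """Positive feature map phi(x) = elu(x) + 1, a common linear-attention choice."""
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def linear_attention(q, k, v):
    """Kernelized linear attention, O(n) in sequence length (stand-in, not TTT)."""
    phi_q, phi_k = elu_plus_one(q), elu_plus_one(k)
    kv = phi_k.T @ v                        # (d, d_v) summary of all key/value pairs
    normalizer = phi_q @ phi_k.sum(axis=0)  # (n,) per-query normalization
    return (phi_q @ kv) / normalizer[:, None]


def key_instance_norm(k, eps=1e-6):
    """Hypothetical key instance norm: per-channel mean/std over the token axis."""
    mean = k.mean(axis=0, keepdims=True)
    std = k.std(axis=0, keepdims=True)
    return (k - mean) / (std + eps)


def max_gap(a, b):
    """Largest absolute elementwise difference between two outputs."""
    return float(np.abs(a - b).max())


rng = np.random.default_rng(0)
n, d = 16, 8
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))
shift = 3.0 * rng.normal(size=(1, d))  # the same offset added to every key

softmax_gap = max_gap(softmax_attention(q, k, v), softmax_attention(q, k + shift, v))
linear_gap = max_gap(linear_attention(q, k, v), linear_attention(q, k + shift, v))
normed_gap = max_gap(
    linear_attention(q, key_instance_norm(k), v),
    linear_attention(q, key_instance_norm(k + shift), v),
)

print(f"softmax:           {softmax_gap:.1e}")  # ~0: the shift cancels in the softmax
print(f"linear:            {linear_gap:.1e}")   # nonzero: the shift changes the weights
print(f"linear + key norm: {normed_gap:.1e}")   # ~0: normalization removes the shift
```

Subtracting the per-channel token mean removes any offset shared by all keys before the keys reach the linear layer. This is one concrete way to recover the Softmax property that the summary refers to.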
📝 Abstract
While linear-complexity attention mechanisms offer a promising alternative to Softmax attention for overcoming the quadratic bottleneck, training such models from scratch remains prohibitively expensive. Inheriting weights from pretrained Transformers provides an appealing shortcut, yet the fundamental representational gap between Softmax and linear attention prevents effective weight transfer. In this work, we address this conversion challenge from two perspectives: architectural alignment and representational alignment. We identify Test-Time Training (TTT) as a linear-complexity architecture whose two-layer dynamic formulation is structurally aligned with Softmax attention, enabling direct inheritance of pretrained attention weights. To further align representational properties, including key shift-invariance and locality, we introduce key instance normalization and a lightweight locality enhancement module. We validate our approach by linearizing Stable Diffusion 3.5, yielding SD3.5-T⁵ (Transformer To Test Time Training). With only 1 hour of fine-tuning on 4×H20 GPUs, SD3.5-T⁵ achieves text-to-image quality comparable to that of the fine-tuned Softmax model, while accelerating inference by 1.32× and 1.47× at 1K and 2K resolutions, respectively.
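The structural alignment between TTT and Softmax attention can be illustrated with two standard identities. Neither identity reproduces the paper's exact layer, and the abstract does not specify the inner model, loss, learning rate, or how the weights are mapped.

1. For each query, Softmax attention is a two-layer network f(x) = W2 act(W1 x). Here W1 holds the scaled keys, act is a softmax, and W2 holds the values.
2. In a TTT layer, the hidden state is itself a small model, updated by a gradient step on a self-supervised loss that reconstructs values from keys. With a linear inner model and one step taken from zero weights, this update reduces to unnormalized linear attention.

The function names and the one-step, non-causal update below are assumptions chosen for clarity.

```python
import numpy as np


def softmax_rows(x):
    """Row-wise softmax."""
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def softmax_attention(q, k, v):
    """Scaled dot-product Softmax attention; q: (n, d), k: (m, d), v: (m, d_v)."""
    return softmax_rows(q @ k.T / np.sqrt(q.shape[-1])) @ v


def two_layer_net(x, w1, w2, act):
    """f(x) = W2 act(W1 x), applied to each row of x."""
    return act(x @ w1.T) @ w2.T


def ttt_linear_one_step(q, k, v, lr):
    """Hypothetical TTT layer with a linear inner model f(x; W) = W x.

    Each token contributes the loss 0.5 * ||W k_t - v_t||^2, whose gradient
    is (W k_t - v_t) k_t^T. Summing these gradients at W0 = 0 and taking one
    step gives W = lr * sum_t v_t k_t^T, so W q is unnormalized linear
    attention. Learned projections, mini-batched online updates and output
    normalization used by real TTT layers are omitted.
    """
    d_k, d_v = k.shape[1], v.shape[1]
    w = np.zeros((d_v, d_k))
    grad = sum(np.outer(w @ k_t - v_t, k_t) for k_t, v_t in zip(k, v))
    w = w - lr * grad
    return q @ w.T  # apply the updated inner model to every query


rng = np.random.default_rng(0)
n, d = 16, 8
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))

# Identity 1: Softmax attention == two-layer net with W1 = K / sqrt(d), W2 = V^T.
attn_as_mlp = two_layer_net(q, k / np.sqrt(d), v.T, softmax_rows)
identity1 = np.allclose(softmax_attention(q, k, v), attn_as_mlp)
print(identity1)  # True

# Identity 2: one TTT step from zero == unnormalized linear attention (q k^T) v.
lr = 0.5
identity2 = np.allclose(ttt_linear_one_step(q, k, v, lr), lr * (q @ k.T) @ v)
print(identity2)  # True
```

Read together, the two identities suggest why a TTT layer with a two-layer inner model is a natural target for inheriting pretrained query, key, and value projections. The second identity also shows why the TTT cost grows linearly rather than quadratically with the number of tokens.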