AI Summary
To address the low training efficiency and high computational cost of diffusion models, this paper proposes an architecture-agnostic token routing mechanism that reduces intermediate-layer computation without modifying the underlying network architecture. Rather than discarding tokens, predefined routes carry them past intermediate layers and reintroduce them at deeper layers. Multiple routes are combined, and an adapted auxiliary loss accounts for all applied routes. The same framework applies to both Transformer-based and state-space model (SSM)-based diffusion backbones. On ImageNet-1K 256×256 class-conditional generation, it achieves a 9.55× convergence speedup over DiT at 400K training iterations and a 25.39× speedup relative to DiT's best benchmark performance at 7M iterations, while also improving FID. Key contributions include: (i) a generic, structure-preserving token routing framework; (ii) compatibility with both Transformers and SSMs; and (iii) a joint improvement in training efficiency and generation quality.
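The routing idea can be illustrated with a small, self-contained sketch. For illustration only, it assumes that a route is a layer span `[start, end)`, that the tokens processed along the route are drawn at random, and that each token is a single float. The names `routed_forward` and `merge_back` are hypothetical, and this is not the authors' implementation. The layer stack itself is left untouched. Only the input to the routed span is shortened, and the bypassing tokens are reinserted unchanged before layer `end`.

```python
import random
from typing import Callable, List, Sequence

# Hypothetical types for illustration: a token is a single float, and a "layer"
# maps a list of tokens to a new list of the same length.
Layer = Callable[[List[float]], List[float]]


def merge_back(full: List[float], kept_idx: List[int], processed: List[float]) -> List[float]:
    """Reinsert processed tokens at their original positions; bypassing tokens keep their stored values."""
    merged = list(full)
    for pos, j in enumerate(kept_idx):
        merged[j] = processed[pos]
    return merged


def routed_forward(
    tokens: List[float],
    layers: Sequence[Layer],
    start: int,
    end: int,
    keep_ratio: float,
    rng: random.Random,
) -> List[float]:
    """Apply `layers`; layers start..end-1 only see a random subset of tokens.

    The remaining tokens bypass that span unchanged and are reintroduced
    before layer `end` (or at the output if end == len(layers)).
    """
    if not 0 <= start < end <= len(layers):
        raise ValueError("route must satisfy 0 <= start < end <= len(layers)")
    x = list(tokens)
    stored: List[float] = []
    kept_idx: List[int] = []
    for i, layer in enumerate(layers):
        if i == start:
            # Assumption: the processed subset is sampled uniformly at random.
            num_keep = max(1, round(keep_ratio * len(x)))
            kept_idx = sorted(rng.sample(range(len(x)), num_keep))
            stored = x  # full sequence, kept aside for the bypassing tokens
            x = [stored[j] for j in kept_idx]
        if i == end:
            x = merge_back(stored, kept_idx, x)
        x = layer(x)
    if end == len(layers):
        x = merge_back(stored, kept_idx, x)
    return x


# Usage: four identical toy layers that add 1.0 to every token they see.
calls: List[int] = []


def add_one(seq: List[float]) -> List[float]:
    calls.append(len(seq))  # record how many tokens this layer processed
    return [t + 1.0 for t in seq]


out = routed_forward([0.0] * 8, [add_one] * 4, start=1, end=3, keep_ratio=0.5, rng=random.Random(0))
print(calls)        # [8, 4, 4, 8]
print(sorted(out))  # [2.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]
```

With `keep_ratio=0.5`, the two layers inside the route process only 4 of the 8 tokens, and the bypassing tokens come out with two fewer updates. The output keeps the full sequence length, which is what lets the network architecture stay unchanged.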
Abstract
Diffusion models have emerged as the mainstream approach for visual generation. However, these models usually suffer from sample inefficiency and high training costs. This issue is particularly pronounced in the standard diffusion transformer architecture due to its quadratic complexity relative to input length. Recent works have addressed this by reducing the number of tokens processed in the model, often through masking. In contrast, this work aims to improve the training efficiency of the diffusion backbone by using predefined routes that store the information of these tokens until it is reintroduced into deeper layers of the model, rather than discarding the tokens entirely. Further, we combine multiple routes and introduce an adapted auxiliary loss that accounts for all applied routes. Our method is not limited to the common transformer-based model; it can also be applied to state-space models. Unlike most current approaches, TREAD achieves this without architectural modifications. Finally, we show that our method reduces the computational cost and simultaneously boosts model performance on the standard benchmark ImageNet-1K 256×256 in class-conditional synthesis. Together, these benefits multiply into a convergence speedup of 9.55× at 400K training iterations compared to DiT and 25.39× compared to the best benchmark performance of DiT at 7M training iterations.
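The abstract's claim that lower per-step cost and faster convergence multiply can be made concrete with back-of-the-envelope arithmetic. The sketch below assumes a hypothetical per-layer cost of `alpha * n + beta * n**2` for `n` tokens (a linear term for projections and MLPs, a quadratic term for attention) and uses made-up layer counts and step counts. None of these numbers are the paper's measurements, and the function names are illustrative.

```python
def layer_cost(n_tokens: float, alpha: float, beta: float) -> float:
    """Hypothetical per-layer cost model: alpha * n (projections/MLP) + beta * n^2 (attention)."""
    return alpha * n_tokens + beta * n_tokens ** 2


def relative_step_cost(num_layers: int, routed_layers: int, keep_ratio: float,
                       n_tokens: int, alpha: float, beta: float) -> float:
    """Cost of one routed forward pass divided by the cost of the full pass."""
    full = num_layers * layer_cost(n_tokens, alpha, beta)
    routed = ((num_layers - routed_layers) * layer_cost(n_tokens, alpha, beta)
              + routed_layers * layer_cost(keep_ratio * n_tokens, alpha, beta))
    return routed / full


def convergence_speedup(baseline_steps: float, method_steps: float, step_cost_ratio: float) -> float:
    """Overall speedup = (fewer steps to reach a target quality) x (cheaper steps)."""
    return (baseline_steps / method_steps) / step_cost_ratio


# Made-up example: 12 layers, 8 of them inside the route, half the tokens kept.
ratio = relative_step_cost(12, 8, 0.5, 256, alpha=0.0, beta=1.0)
print(ratio)                                        # 0.5
print(convergence_speedup(400_000, 100_000, ratio))  # 8.0
```

In this toy setting, needing 4× fewer steps and paying half the cost per step compounds to an 8× speedup. Setting `beta=0` approximates a backbone whose cost is linear in sequence length, such as an SSM, where the same route saves less (a step-cost ratio of about 0.67 instead of 0.5).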