Faster Video Diffusion with Trainable Sparse Attention

📅 2025-05-19
📈 Citations: 0
Influential: 0
🤖 AI Summary
Video diffusion Transformers (DiTs) are bottlenecked by the quadratic cost of 3D self-attention, which hinders scalable training and deployment. This paper proposes VSA, a trainable sparse attention mechanism implemented as a single end-to-end differentiable kernel. A coarse stage pools tokens into tiles and selects the critical ones, and a fine stage computes token-level attention only inside the selected tiles under a hardware-friendly block layout. VSA requires no post-hoc profiling and sustains 85% of FlashAttention-3's MFU. In pretraining experiments on DiTs from 60M to 1.4B parameters, it cuts training FLOPs by 2.53× with no drop in diffusion loss. Retrofitted to Wan-2.1, VSA speeds up attention by 6× and shortens end-to-end generation from 31s to 18s with comparable quality.
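The two speed figures can be reconciled with a back-of-the-envelope Amdahl's-law calculation. The abstract reports a 6× speedup for attention time and a drop from 31s to 18s end to end. Assuming (this is not stated in the source) that all non-attention work takes the same time before and after, the sketch below solves for the share of baseline runtime that attention would have to occupy. The function name `attention_fraction` is hypothetical.

```python
def attention_fraction(t_before, t_after, attn_speedup):
    """Solve t_after = t_before * ((1 - f) + f / attn_speedup) for f.

    f is the fraction of baseline runtime spent in attention, under the
    assumption that everything outside attention is unchanged.
    """
    saved = 1.0 - t_after / t_before          # fraction of runtime removed
    return saved / (1.0 - 1.0 / attn_speedup)  # fraction attributable to attention


f = attention_fraction(31.0, 18.0, 6.0)
print(f"implied attention share of baseline runtime: {f:.2f}")  # 0.50
print(f"end-to-end speedup: {31.0 / 18.0:.2f}x")                # 1.72x
```

Under that assumption, attention accounts for roughly half of the baseline 31s, which is consistent with a 6× attention speedup producing about a 1.7× end-to-end gain.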

📝 Abstract
Scaling video diffusion transformers (DiTs) is limited by their quadratic 3D attention, even though most of the attention mass concentrates on a small subset of positions. We turn this observation into VSA, a trainable, hardware-efficient sparse attention that replaces full attention at both training and inference. In VSA, a lightweight coarse stage pools tokens into tiles and identifies high-weight critical tokens; a fine stage computes token-level attention only inside those tiles, subject to a block computing layout to ensure hardware efficiency. This leads to a single differentiable kernel that trains end-to-end, requires no post-hoc profiling, and sustains 85% of FlashAttention3 MFU. We perform a large sweep of ablation studies and scaling-law experiments by pretraining DiTs from 60M to 1.4B parameters. VSA reaches a Pareto point that cuts training FLOPs by 2.53× with no drop in diffusion loss. Retrofitting the open-source Wan-2.1 model speeds up attention time by 6× and lowers end-to-end generation time from 31s to 18s with comparable quality. These results establish trainable sparse attention as a practical alternative to full attention and a key enabler for further scaling of video diffusion models.
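The coarse-then-fine idea described in the abstract can be illustrated with a minimal NumPy sketch. This is not the paper's kernel: it assumes a flat 1D token sequence, mean pooling, a fixed `top_k` tiles per query tile, and a plain Python loop in place of a fused GPU kernel. The names `coarse_to_fine_attention`, `tile`, and `top_k` are hypothetical, and the sketch shows only the forward selection logic, not how gradients reach the coarse stage.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def coarse_to_fine_attention(q, k, v, tile=4, top_k=2):
    """Illustrative two-stage sparse attention; q, k, v have shape (n, d)."""
    n, d = q.shape
    assert n % tile == 0, "sequence length must be a multiple of the tile size"
    n_tiles = n // tile
    assert top_k <= n_tiles
    scale = 1.0 / np.sqrt(d)

    # Coarse stage: mean-pool each tile to a single token and score tile pairs.
    q_pool = q.reshape(n_tiles, tile, d).mean(axis=1)
    k_pool = k.reshape(n_tiles, tile, d).mean(axis=1)
    coarse = softmax(q_pool @ k_pool.T * scale)          # (n_tiles, n_tiles)
    # For every query tile, keep the top_k key tiles with the largest weight.
    selected = np.argsort(-coarse, axis=-1)[:, :top_k]   # (n_tiles, top_k)

    # Fine stage: token-level attention restricted to the selected key tiles.
    k_tiles = k.reshape(n_tiles, tile, d)
    v_tiles = v.reshape(n_tiles, tile, d)
    out = np.zeros_like(v)
    for i in range(n_tiles):
        q_blk = q[i * tile:(i + 1) * tile]               # (tile, d)
        k_sel = k_tiles[selected[i]].reshape(-1, d)      # (top_k * tile, d)
        v_sel = v_tiles[selected[i]].reshape(-1, d)
        weights = softmax(q_blk @ k_sel.T * scale)
        out[i * tile:(i + 1) * tile] = weights @ v_sel
    return out, selected


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
out, selected = coarse_to_fine_attention(q, k, v, tile=4, top_k=2)
print(out.shape, selected.shape)
```

With `top_k` equal to the number of tiles, every key tile is selected and the result matches dense attention, which is a convenient sanity check. Smaller `top_k` values cut the fine-stage work roughly in proportion to the fraction of tiles kept.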
Problem

Research questions and friction points this paper is trying to address.

Reducing quadratic 3D attention cost in video diffusion transformers
Enabling efficient trainable sparse attention for video generation
Improving speed and scalability of video diffusion models
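
To make the quadratic cost concrete, the sketch below counts query-key pairs for full 3D attention over a latent video grid. The grid sizes are hypothetical and chosen only for illustration; they are not taken from the paper or from Wan-2.1.

```python
def attention_pairs(frames, height, width):
    """Return (token count, query-key pair count) for full 3D attention."""
    n = frames * height * width  # every latent position is one token
    return n, n * n              # full attention scores every token pair


# Hypothetical 32x32 latent frames; only the frame count varies.
for frames in (8, 16, 32):
    n, pairs = attention_pairs(frames, 32, 32)
    print(f"{frames:>2} frames -> {n:>6} tokens, {pairs:>13} pairs")
```

Doubling the number of frames doubles the token count but quadruples the number of pairs, which is the scaling that sparse attention aims to avoid.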
Innovation

Methods, ideas, or system contributions that make the work stand out.

Trainable sparse attention replaces full attention
Lightweight coarse stage identifies critical tokens
Differentiable kernel ensures high hardware efficiency