🤖 AI Summary
This work addresses the quadratic computational complexity of self-attention in Transformers over extremely long contexts, which scales as $O(n^2 d)$ for sequence length $n$ and feature dimension $d$. To mitigate this bottleneck, the authors propose Sparse Feature Attention (SFA), the first systematic exploration of sparsity along the feature dimension: queries and keys are represented via $k$-sparse coding, substantially reducing computational overhead while preserving high-dimensional representational capacity. An I/O-aware FlashSFA kernel efficiently implements the sparse-overlap computation and optimizes KV-cache usage. Experiments show that SFA matches the accuracy of dense baselines when pretraining GPT-2 and Qwen3, achieving up to a 2.5× speedup and a nearly 50% reduction in FLOPs and KV-cache memory, while maintaining strong retrieval accuracy and robustness on long-context tasks.
📝 Abstract
Scaling Transformers to ultra-long contexts is bottlenecked by the $O(n^2 d)$ cost of self-attention. Existing methods reduce this cost along the sequence axis through local windows, kernel approximations, or token-level sparsity, but these approaches consistently degrade accuracy. In this paper, we instead explore an orthogonal axis: feature sparsity. We propose Sparse Feature Attention (SFA), where queries and keys are represented as $k$-sparse codes that preserve high-dimensional expressivity while reducing the cost of attention from $\Theta(n^2 d)$ to $\Theta(n^2 k^2/d)$. To make this efficient at scale, we introduce FlashSFA, an I/O-aware kernel that extends FlashAttention to operate directly on sparse overlaps without materializing dense score matrices. Across GPT-2 and Qwen3 pretraining, SFA matches dense baselines while improving speed by up to $2.5\times$ and reducing FLOPs and KV-cache by nearly 50\%. On synthetic and downstream benchmarks, SFA preserves retrieval accuracy and robustness at long contexts, outperforming short-embedding baselines that collapse feature diversity. These results establish feature-level sparsity as a complementary and underexplored axis for efficient attention, enabling Transformers to scale to orders-of-magnitude longer contexts with minimal quality loss. Code is available at https://github.com/YannX1e/Sparse-Feature-Attention.
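To make the feature-sparsity idea concrete, here is a minimal NumPy sketch, not the paper's implementation: queries and keys are sparsified by a hypothetical top-$k$ magnitude coder, so each attention score only receives contributions from features where the two $k$-sparse supports overlap (in expectation about $k^2/d$ of the $d$ features for random supports, the source of the $\Theta(n^2 k^2/d)$ cost). The function names and the top-$k$ coder are illustrative assumptions; the actual SFA coder and the FlashSFA kernel, which avoids materializing the dense score matrix, may differ.

```python
import numpy as np

def topk_sparsify(x, k):
    """Keep the k largest-magnitude features per row, zero out the rest.

    Hypothetical stand-in for SFA's k-sparse coding step; the paper's
    coder may be learned rather than a fixed top-k rule.
    """
    idx = np.argpartition(np.abs(x), -k, axis=-1)[..., -k:]
    out = np.zeros_like(x)
    np.put_along_axis(out, idx, np.take_along_axis(x, idx, axis=-1), axis=-1)
    return out

def sparse_feature_attention(Q, K, V, k):
    """Softmax attention over k-sparse queries and keys.

    Dense emulation for clarity: each q·k product only picks up terms
    where the sparse supports overlap, so a sparse kernel would do
    roughly k^2/d of the work per score instead of d multiplications.
    """
    d = Q.shape[-1]
    Qs, Ks = topk_sparsify(Q, k), topk_sparsify(K, k)
    scores = Qs @ Ks.T / np.sqrt(d)  # zero wherever supports are disjoint
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
n, d, k = 8, 64, 8  # k << d: each score touches ~k^2/d = 1 feature on average
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
out = sparse_feature_attention(Q, K, V, k)
print(out.shape)  # (8, 64)
```

Setting `k = d` makes `topk_sparsify` the identity and recovers ordinary dense attention, which is a convenient sanity check when experimenting with the sparsity level.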