🤖 AI Summary
Standard softmax attention in Transformers incurs $O(N^2D)$ time complexity, severely hindering training and inference efficiency for large-scale models. This work proposes a high-performance linear attention (LA) implementation: we design novel forward and backward propagation algorithms and develop highly optimized CUDA kernels, achieving practical speedups while preserving the theoretical $O(ND^2)$ complexity. Our key innovations include a low-memory parallel computation scheme and kernel-level optimizations—such as fused memory access, register blocking, and warp-level reduction—to maximize hardware utilization. Evaluated on a 1.4B-parameter language model, our implementation achieves 3.3× faster training and inference and reduces GPU memory consumption by 3.6× compared to the state-of-the-art LA baselines, with no accuracy degradation across standard downstream evaluation tasks.
📝 Abstract
The original softmax-based attention mechanism (regular attention) in the extremely successful Transformer architecture computes attention between $N$ tokens, each embedded in a $D$-dimensional head, with a time complexity of $O(N^2D)$. Given the success of Transformers, improving their runtime during both training and inference is a popular research area. One such approach is the introduction of linear attention (LA) mechanisms, which offer a linear time complexity of $O(ND^2)$ and have demonstrated accuracy comparable to regular attention. However, LA in practice lags behind its theoretical efficiency. We propose a novel method for LA's forward and backward passes, along with a highly optimized CUDA implementation. Our approach outperforms the state of the art by 3.3× in speed and reduces memory consumption by 3.6×. We validate these improvements in both single-layer and end-to-end settings by training a 1.4-billion-parameter language model, which demonstrates expressivity similar to regular attention on major reasoning benchmarks.
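The complexity gap described above comes from reassociating the attention product: with a feature map $\phi$, $(\phi(Q)\phi(K)^\top)V$ can be computed as $\phi(Q)(\phi(K)^\top V)$, replacing the $N \times N$ attention matrix with a $D \times D$ state. A minimal NumPy sketch of this idea (the feature map `phi` here is an illustrative ReLU-plus-one choice, not necessarily the one used in the paper, and this ignores the paper's CUDA-level optimizations entirely):

```python
import numpy as np

def softmax_attention(Q, K, V):
    # Regular attention, O(N^2 D): materializes the N x N score matrix.
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def linear_attention(Q, K, V, phi=lambda x: np.maximum(x, 0.0) + 1.0):
    # Linear attention, O(N D^2): reassociate (phi(Q) phi(K)^T) V
    # as phi(Q) (phi(K)^T V), so no N x N matrix is ever formed.
    Qp, Kp = phi(Q), phi(K)
    KV = Kp.T @ V               # D x D summary of keys and values
    Z = Qp @ Kp.sum(axis=0)     # per-row normalizer, length N
    return (Qp @ KV) / Z[:, None]
```

Both functions map `(N, D)` inputs to an `(N, D)` output; only the order of the matrix products, and hence the asymptotic cost, differs.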