AI Summary
This work proposes a hardware-agnostic, exact attention mechanism that addresses key limitations of existing acceleration methods: approximating softmax, relying on Tensor Cores, or losing FP32 throughput on long sequences because of sequential computation. By reformulating online softmax as a prefix scan over an associative monoid, the method achieves O(n) extra memory and O(log n) parallel depth. It requires no retraining and serves as a drop-in replacement for existing attention modules. It is implemented in Triton and CUDA C++ without Tensor Core dependency, and its update rule has a provable FP32 relative error bound of O(u log n). Experiments show 1.3–3.5× speedup over memory-efficient SDPA on A100 GPUs in FP32 (1K–16K tokens), 1.97–2.27× on BERT, 1.5–1.6× over the Math baseline on Jetson TX2 (64–900 tokens), and 17.8–20.2% higher throughput under LLaMA-13B offloading at sequence lengths of 32K and above.
Abstract
Existing attention accelerators often trade exact softmax semantics, depend on fused Tensor Core kernels, or incur sequential depth that limits FP32 throughput on long sequences. We present \textbf{ELSA}, an algorithmic reformulation of online softmax attention that (i)~preserves exact softmax semantics in real arithmetic with a \emph{provable} $\mathcal{O}(u\log n)$ FP32 relative error bound; (ii)~casts the online softmax update as a prefix scan over an associative monoid $(m,S,W)$, yielding $O(n)$ extra memory and $O(\log n)$ parallel depth; and (iii)~is Tensor-Core independent, implemented in Triton and CUDA C++, and deployable as a \emph{drop-in replacement} requiring no retraining or weight modification. Unlike FlashAttention-2/3, which rely on HMMA/GMMA Tensor Core instructions and provide no compatible FP32 path, ELSA operates identically on A100s and resource-constrained edge devices such as Jetson TX2 -- making it the only hardware-agnostic exact-attention kernel that reduces parallel depth to $O(\log n)$ at full precision. On A100 FP32 benchmarks (1K--16K tokens), ELSA delivers $1.3$--$3.5\times$ speedup over memory-efficient SDPA and $1.97$--$2.27\times$ on BERT; on Jetson TX2, ELSA achieves $1.5$--$1.6\times$ over Math (64--900 tokens), with $17.8$--$20.2\%$ throughput gains under LLaMA-13B offloading at $\ge$32K. In FP16, ELSA approaches hardware-fused baselines at long sequences while retaining full FP32 capability, offering a unified kernel for high-precision inference across platforms. Our code and implementation are available at https://github.com/ming053l/ELSA.
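To make the monoid formulation concrete, the following is a minimal pure-Python sketch for a single query row. It is not the ELSA kernel. It assumes that $m$ is the running maximum score, $S$ the rescaled sum of exponentials, and $W$ the rescaled weighted sum of value vectors; the abstract names the triple $(m,S,W)$ but does not define it, so this reading and the helper names `lift`, `combine`, `tree_reduce`, `attention_row`, and `reference_row` are illustrative assumptions. For brevity the sketch performs a pairwise tree reduction, which yields only the final element of the prefix scan.

```python
import math

def lift(score, value):
    """Map one (score, value) pair to a monoid element (m, S, W)."""
    return (score, 1.0, list(value))

def combine(a, b):
    """Merge two partial softmax states (assumed form of the operator)."""
    m1, s1, w1 = a
    m2, s2, w2 = b
    m = max(m1, m2)
    # Rescale both sides to the shared maximum so every exponent is <= 0.
    c1 = math.exp(m1 - m)
    c2 = math.exp(m2 - m)
    s = s1 * c1 + s2 * c2
    w = [x * c1 + y * c2 for x, y in zip(w1, w2)]
    return (m, s, w)

def tree_reduce(elems):
    """Pairwise reduction; merges within a round are independent of each other."""
    rounds = 0
    while len(elems) > 1:
        nxt = [combine(elems[i], elems[i + 1])
               for i in range(0, len(elems) - 1, 2)]
        if len(elems) % 2:
            nxt.append(elems[-1])  # odd element carries over unchanged
        elems = nxt
        rounds += 1
    return elems[0], rounds

def attention_row(scores, values):
    """Softmax-weighted average of values, computed through the monoid."""
    (m, s, w), rounds = tree_reduce([lift(x, v) for x, v in zip(scores, values)])
    return [x / s for x in w], rounds

def reference_row(scores, values):
    """Standard max-subtracted softmax attention, for comparison."""
    m = max(scores)
    p = [math.exp(x - m) for x in scores]
    z = sum(p)
    dim = len(values[0])
    return [sum(p[i] * values[i][j] for i in range(len(scores))) / z
            for j in range(dim)]
```

Because `combine` is associative, the merges can be grouped in any order. A tree grouping finishes in $\lceil \log_2 n \rceil$ rounds, whereas the usual left-to-right online softmax update takes $n-1$ dependent steps. In real arithmetic both groupings return the same result as `reference_row`; in floating point they differ only by rounding.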