🤖 AI Summary
This work addresses the high computational overhead of α-entmax sparse attention, which stems from the iterative computation of the normalizer τ and limits its use in long-context settings. To overcome this limitation, the authors propose AdaSplash-2, which builds a coarse histogram of attention scores on the fly and stores it in on-chip SRAM to obtain an accurate initial estimate of τ, typically reducing the number of iterations to one or two. AdaSplash-2 also uses a sparsity-aware GPU kernel that skips all-zero blocks to accelerate both the forward and backward passes. When block sparsity is moderate to high (e.g., >60%), AdaSplash-2 matches or exceeds the per-step training speed of FlashAttention-2. Models trained with it match softmax baselines on short-context downstream tasks and achieve substantial gains in long-context settings.
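The τ computation summarized above can be illustrated with a small pure-Python sketch. It assumes the standard α-entmax form p_i = [(α-1)z_i - τ]_+^(1/(α-1)), where τ is chosen so that the p_i sum to 1, and it restricts itself to 1 < α ≤ 2. The histogram step is an illustrative construction and is not AdaSplash-2's kernel. Scores are binned over [max - 1, max], which is the only range that can receive probability. Each bin's lower and upper edges then bound the probability mass at a candidate τ, and those bounds certify a narrow bracket around τ. Newton's method starts from the left end of that bracket. The function names and the `num_bins` parameter are hypothetical. The GPU aspects (online histogram construction, SRAM residency, tiling) are not modeled.

```python
import random


def histogram_bracket(x, alpha, num_bins=64):
    """Bracket the entmax threshold tau* using only a coarse histogram of x.

    x holds the scaled scores (alpha - 1) * z. Scores below max(x) - 1 can never
    receive probability, so the histogram only covers [max(x) - 1, max(x)].
    Replacing every score by its bin's lower (upper) edge gives a lower (upper)
    bound on the mass at a candidate tau, which tells us on which side of tau*
    that candidate lies (up to floating-point rounding).
    """
    m, n, p = max(x), len(x), 1.0 / (alpha - 1.0)
    lo_edge, width = m - 1.0, 1.0 / num_bins
    counts = [0] * num_bins
    for xi in x:
        if xi >= lo_edge:
            counts[min(int((xi - lo_edge) / width), num_bins - 1)] += 1

    def mass_bound(tau, upper):
        shift = 1 if upper else 0
        return sum(c * max(lo_edge + (j + shift) * width - tau, 0.0) ** p
                   for j, c in enumerate(counts))

    edges = [lo_edge + k * width for k in range(num_bins + 1)]
    # The classical bracket [m - 1, m - n^(1 - alpha)] is the fallback on each side.
    lo = max([m - 1.0] + [e for e in edges if mass_bound(e, False) >= 1.0])
    hi = min([m - n ** (1.0 - alpha)] + [e for e in edges if mass_bound(e, True) <= 1.0])
    return lo, hi


def newton_tau(x, alpha, tau, tol=1e-9, max_iter=100):
    """Newton's method on f(tau) = sum_i [x_i - tau]_+^(1/(alpha-1)) - 1.

    For 1 < alpha <= 2, f is convex and decreasing, so starting where
    f(tau) >= 0 (left of the root) gives iterates that rise towards tau*
    without overshooting. Returns (tau, number_of_newton_updates).
    """
    p = 1.0 / (alpha - 1.0)
    for it in range(max_iter):
        active = [xi - tau for xi in x if xi > tau]
        f = sum(d ** p for d in active) - 1.0
        if f <= tol:
            return tau, it
        # tau - f / f'(tau), with f'(tau) = -p * sum_i [x_i - tau]_+^(p - 1)
        tau += f / (p * sum(d ** (p - 1.0) for d in active))
    return tau, max_iter


def entmax(z, alpha=1.5, init="histogram"):
    """Sparse probabilities for one score row; init is 'histogram' or 'classic'."""
    x = [(alpha - 1.0) * zi for zi in z]  # entmax thresholds the scaled scores
    if init == "histogram":
        start = histogram_bracket(x, alpha)[0]
    else:
        start = max(x) - 1.0  # left end of the classical bracket
    tau, iters = newton_tau(x, alpha, start)
    p = 1.0 / (alpha - 1.0)
    return [max(xi - tau, 0.0) ** p for xi in x], iters


if __name__ == "__main__":
    rng = random.Random(0)
    z = [rng.gauss(0.0, 3.0) for _ in range(512)]
    for init in ("classic", "histogram"):
        probs, iters = entmax(z, alpha=1.5, init=init)
        print(init, "newton updates:", iters, "nonzeros:", sum(q > 0.0 for q in probs))
```

For 1 < α ≤ 2, f is convex and decreasing, so the Newton map is nondecreasing to the left of the root. A start closer to τ therefore never needs more updates than the classical start at max - 1. How many updates are saved in practice depends on the score distribution and the bin count, so no figure is claimed here.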
📝 Abstract
Sparse attention has been proposed as a way to alleviate the quadratic cost of transformers, a central bottleneck in long-context training. A promising line of work is α-entmax attention, a differentiable sparse alternative to softmax that enables input-dependent sparsity but has lagged behind softmax because of the computational overhead of computing the normalizer τ. In this paper, we introduce AdaSplash-2, which addresses this limitation through a novel histogram-based initialization that typically reduces the number of iterations needed to compute τ to 1-2. The key idea is to compute a coarse histogram of attention scores on the fly and store it in on-chip SRAM, yielding a more accurate initialization that enables fast forward and backward computation. Combined with a sparsity-aware GPU implementation that skips zero blocks with low overhead, AdaSplash-2 matches or improves on the per-step training time of FlashAttention-2 when block sparsity is moderate to high (e.g., >60%), which often occurs at long context lengths. On downstream tasks, models trained with our efficient α-entmax attention match softmax baselines at short context lengths and achieve substantial gains in long-context settings.
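The block-skipping idea can be sketched in the same hedged way. The pure-Python code below assumes the per-row thresholds τ are already known. Here they come from plain bisection rather than the histogram initialization. Each (query block, key block) tile is marked live if any of its entries has (α-1)s_ij > τ_i, and O = PV is accumulated only over live tiles. Block sparsity is taken to be the fraction of skipped tiles; this is an assumed reading of the abstract's term, not a confirmed definition. All names (`entmax_tau`, `block_mask`, `entmax_attention`, `block`) are hypothetical. A real kernel would fuse these steps, keep Q/K/V tiles in SRAM, derive the mask far more cheaply, and also skip tiles in the backward pass.

```python
import random


def entmax_tau(x, alpha, iters=60):
    """Bisection for tau such that sum_i [x_i - tau]_+^(1/(alpha-1)) = 1.

    x holds the scaled scores (alpha - 1) * s for one query row.
    """
    p, m = 1.0 / (alpha - 1.0), max(x)
    lo, hi = m - 1.0, m - len(x) ** (1.0 - alpha)  # classical bracket
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if sum(max(xi - mid, 0.0) ** p for xi in x) >= 1.0:
            lo = mid  # too much mass: the threshold must move up
        else:
            hi = mid
    return 0.5 * (lo + hi)


def block_mask(scores, alpha, block):
    """Mark each (query block, key block) tile that holds any nonzero probability."""
    x = [[(alpha - 1.0) * s for s in row] for row in scores]
    taus = [entmax_tau(row, alpha) for row in x]
    nb = (len(scores) + block - 1) // block
    mask = [[False] * nb for _ in range(nb)]
    for i, row in enumerate(x):
        for j, xij in enumerate(row):
            if xij > taus[i]:
                mask[i // block][j // block] = True
    return mask, taus


def entmax_attention(scores, values, alpha=1.5, block=4):
    """O = P V computed tile by tile, skipping tiles whose probabilities are all zero.

    Returns (output, block_sparsity), where block_sparsity is the skipped fraction.
    """
    mask, taus = block_mask(scores, alpha, block)
    p, n, d = 1.0 / (alpha - 1.0), len(scores), len(values[0])
    out = [[0.0] * d for _ in range(n)]
    skipped = 0
    for bi, row_mask in enumerate(mask):
        for bj, live in enumerate(row_mask):
            if not live:
                skipped += 1  # no probabilities and no P @ V work for this tile
                continue
            for i in range(bi * block, min((bi + 1) * block, n)):
                for j in range(bj * block, min((bj + 1) * block, n)):
                    pij = max((alpha - 1.0) * scores[i][j] - taus[i], 0.0) ** p
                    for k in range(d):
                        out[i][k] += pij * values[j][k]
    return out, skipped / (len(mask) ** 2)


if __name__ == "__main__":
    rng = random.Random(0)
    n = 32
    # Scores that decay with |i - j|, so far-off tiles become exactly zero.
    scores = [[-0.8 * abs(i - j) + rng.gauss(0.0, 0.3) for j in range(n)] for i in range(n)]
    values = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(n)]
    out, sparsity = entmax_attention(scores, values)
    print(f"block sparsity: {sparsity:.2f}")
```

Skipped tiles contribute exactly zero to O, so the tiled result equals the dense one. The work saved grows with the fraction of dead tiles, which is consistent with the abstract's point that the speed-up over a dense kernel depends on block sparsity being high.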