AdaSplash-2: Faster Differentiable Sparse Attention

📅 2026-04-16

🤖 AI Summary
This work addresses the high computational overhead of α-entmax sparse attention, which stems from the iterative computation of the normalization constant τ and hinders its use in long-context settings. To overcome this limitation, the authors propose AdaSplash-2, which builds a coarse histogram of attention scores on the fly and stores it in on-chip SRAM to obtain a more accurate initial estimate of τ, typically reducing the number of required iterations to one or two. AdaSplash-2 also includes a sparsity-aware GPU kernel that skips all-zero blocks to accelerate both the forward and backward passes. At moderate-to-high block sparsity (e.g., >60%), AdaSplash-2 matches or improves on the per-step training time of FlashAttention-2. On downstream tasks, models trained with it match softmax baselines at short context lengths and achieve substantial gains in long-context settings.
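To make the τ computation concrete, here is a minimal pure-Python sketch for a single row of scores, assuming 1 < α ≤ 2. It uses the standard α-entmax form p_i = [(α−1)z_i − τ]_+^{1/(α−1)}, where τ is the unique value that makes the p_i sum to 1. The helper names (`histogram_bracket`, `entmax_hist_newton`), the 16-bin default, the bin-edge bounding rule, and the Newton refinement are illustrative assumptions, not the authors' kernel. The sketch only shows how a coarse histogram of the scores can shrink the search interval for τ so that few refinement steps remain; AdaSplash-2 keeps its histogram in on-chip SRAM inside a GPU kernel, whereas this sketch uses an ordinary Python list.

```python
from typing import List, Tuple


def _mass(s: List[float], tau: float, p: float) -> float:
    """Sum of [s_i - tau]_+^p; alpha-entmax needs this to equal 1."""
    return sum((x - tau) ** p for x in s if x > tau)


def histogram_bracket(s: List[float], alpha: float, num_bins: int = 16) -> Tuple[float, float]:
    """Hypothetical helper: narrow the search interval for tau with a coarse histogram.

    s holds the scaled scores (alpha - 1) * z for one query row.
    """
    p = 1.0 / (alpha - 1.0)
    top = max(s)
    # Standard bracket: mass(top - 1) >= 1 and mass(top - n**(1 - alpha)) <= 1.
    lo, hi = top - 1.0, top - len(s) ** (1.0 - alpha)
    width = 1.0 / num_bins
    edges = [lo + j * width for j in range(num_bins)] + [top]
    counts = [0] * num_bins
    for x in s:
        if x >= lo:  # scores below lo contribute nothing for any tau >= lo
            counts[min(int((x - lo) / width), num_bins - 1)] += 1
    for k in range(num_bins):
        t = edges[k]
        # Every score in bin j >= k lies in [edges[j], edges[j + 1]], so its
        # contribution [x - t]^p is bounded by the two edges of its bin.
        lower = sum(counts[j] * (edges[j] - t) ** p for j in range(k, num_bins))
        upper = sum(counts[j] * (edges[j + 1] - t) ** p for j in range(k, num_bins))
        if lower >= 1.0:
            lo = max(lo, t)  # mass(t) >= 1, hence tau >= t
        if upper <= 1.0:
            hi = min(hi, t)  # mass(t) <= 1, hence tau <= t
    return lo, hi


def entmax_hist_newton(z: List[float], alpha: float = 1.5, num_bins: int = 16,
                       tol: float = 1e-12, max_iter: int = 50):
    """alpha-entmax of one score row; returns (probs, tau, newton_steps)."""
    if not 1.0 < alpha <= 2.0:
        raise ValueError("this sketch assumes 1 < alpha <= 2")
    p = 1.0 / (alpha - 1.0)
    s = [(alpha - 1.0) * x for x in z]
    lo, hi = histogram_bracket(s, alpha, num_bins)
    # For alpha <= 2 the mass is convex and decreasing in tau, so Newton steps
    # started at the left end of the bracket increase tau toward the root.
    tau, steps = lo, 0
    while steps < max_iter:
        gap = _mass(s, tau, p) - 1.0
        if gap <= tol:
            break
        slope = -p * sum((x - tau) ** (p - 1.0) for x in s if x > tau)
        tau = min(tau - gap / slope, hi)
        steps += 1
    probs = [(x - tau) ** p if x > tau else 0.0 for x in s]
    total = sum(probs)
    return [w / total for w in probs], tau, steps
```

For example, `entmax_hist_newton([1.0, 0.5, -1.0], alpha=2.0)` (α = 2 is sparsemax) should return probabilities [0.75, 0.25, 0.0] with τ = 0.25 after one Newton step, because the 16-bin histogram already brackets τ in [0.1875, 0.3125]. Starting Newton from the lower end of the bracket is a design choice of this sketch: it relies on the mass function being convex and decreasing in τ when α ≤ 2, so the iterates approach τ from below without overshooting.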

📝 Abstract
Sparse attention has been proposed as a way to alleviate the quadratic cost of transformers, a central bottleneck in long-context training. A promising line of work is α-entmax attention, a differentiable sparse alternative to softmax that enables input-dependent sparsity yet has lagged behind softmax due to the computational overhead necessary to compute the normalizer τ. In this paper, we introduce AdaSplash-2, which addresses this limitation through a novel histogram-based initialization that reduces the number of iterations needed to compute τ to typically 1–2. The key idea is to compute a coarse histogram of attention scores on the fly and store it in on-chip SRAM, yielding a more accurate initialization that enables fast forward and backward computation. Combined with a sparsity-aware GPU implementation that skips zero blocks with low overhead, AdaSplash-2 matches or improves per-step training time relative to FlashAttention-2 when block sparsity is moderate-to-high (e.g., >60%), which often occurs at long-context lengths. On downstream tasks, models trained with our efficient α-entmax attention match softmax baselines at short-context lengths and achieve substantial gains in long-context settings.
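The block-skipping idea can be illustrated with a small pure-Python sketch under several assumptions: no causal mask, a hypothetical 16×16 tile size, τ found by plain bisection rather than the histogram initialization, and all scores materialized densely. A (query block, key block) tile whose scaled scores all fall at or below their rows' τ contributes exactly zero to the output, so the P·V work for that tile can be skipped; the returned fraction of skipped tiles is one plausible way to measure block sparsity. The function names are hypothetical, and because everything runs on the CPU, the sketch demonstrates that skipping preserves the result, not the speedup a fused GPU kernel obtains.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def entmax_tau(s: List[float], alpha: float, iters: int = 100) -> float:
    """Plain bisection for tau on one row of scaled scores s = (alpha - 1) * z."""
    p = 1.0 / (alpha - 1.0)
    lo, hi = max(s) - 1.0, max(s) - len(s) ** (1.0 - alpha)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if sum((x - mid) ** p for x in s if x > mid) >= 1.0:
            lo = mid  # mass >= 1: tau lies at or to the right of mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def block_sparse_entmax_attention(q: Matrix, k: Matrix, v: Matrix, alpha: float = 1.5,
                                  block: int = 16) -> Tuple[Matrix, float]:
    """Hypothetical helper: non-causal alpha-entmax attention that skips all-zero tiles.

    Returns the attention output and the fraction of skipped tiles.
    """
    n_q, n_k, d = len(q), len(k), len(q[0])
    p = 1.0 / (alpha - 1.0)
    scale = (alpha - 1.0) / math.sqrt(d)
    # Scaled scores s_ij = (alpha - 1) * <q_i, k_j> / sqrt(d), computed densely here.
    s = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    tau = [entmax_tau(row, alpha) for row in s]
    out = [[0.0] * len(v[0]) for _ in range(n_q)]
    q_starts, k_starts = range(0, n_q, block), range(0, n_k, block)
    kept = 0
    for qb in q_starts:
        rows = range(qb, min(qb + block, n_q))
        for kb in k_starts:
            cols = range(kb, min(kb + block, n_k))
            # A tile is all-zero when none of its scores exceeds its row's tau.
            if not any(s[i][j] > tau[i] for i in rows for j in cols):
                continue  # skip the P @ V work for this tile
            kept += 1
            for i in rows:
                for j in cols:
                    if s[i][j] > tau[i]:
                        w = (s[i][j] - tau[i]) ** p  # alpha-entmax probability for (i, j)
                        for c, v_jc in enumerate(v[j]):
                            out[i][c] += w * v_jc
    return out, 1.0 - kept / (len(q_starts) * len(k_starts))
```

For instance, if the keys form four well-separated clusters aligned to 16-token blocks and every cross-cluster score falls below τ, only the four diagonal tiles survive out of sixteen, so the function reports a block sparsity of 0.75 while producing the same output as the dense computation.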
Problem

Research questions and friction points this paper is trying to address.

sparse attention
α-entmax
computational overhead
normalizer τ
long-context training
Innovation

Methods, ideas, or system contributions that make the work stand out.

differentiable sparse attention
α-entmax
histogram-based initialization
on-chip SRAM
sparsity-aware GPU implementation
Nuno Gonçalves
Institute for Systems and Robotics, University of Coimbra
Biometrics, Computer Vision, Steganography, Robotics, Medical Imaging
Hugo Pitorro
Instituto Superior Técnico, Universidade de Lisboa, Instituto de Telecomunicações
Vlad Niculae
University of Amsterdam
Structured Prediction, Natural Language Processing, Machine Learning
Edoardo Ponti
University of Edinburgh
Lei Li
Associate Professor, School of Computer Science, Carnegie Mellon University
Machine Learning, Natural Language Processing, Machine Translation, LLM, AI Drug Discovery
Andre Martins
Instituto Superior Técnico, Universidade de Lisboa, Instituto de Telecomunicações, TransPerfect, ELLIS Unit Lisbon
Marcos Treviso
Instituto Superior Técnico, Universidade de Lisboa, Instituto de Telecomunicações, ELLIS Unit Lisbon