🤖 AI Summary
This work addresses the quadratic complexity of attention mechanisms that hinders efficient long-context reasoning. Existing block-sparse approaches mitigate this cost, but their coarse granularity caps the sparsity they can achieve. The authors propose an importance-guided online token reordering strategy that restructures the non-contiguous token loading process within FlashAttention, coupled with a dynamic early-stopping mechanism. This approach achieves substantially improved sparsity efficiency with minimal preprocessing overhead. Evaluated on Llama-3.1-8B with 128K context length, the method reduces mean squared error (MSE) by 3.82× at the same sparsity level and decreases prefill computation density by 3.31× at equivalent MSE, yielding a 7.51× speedup in attention computation and a 3.81× end-to-end acceleration while preserving model accuracy.
📝 Abstract
Attention scales quadratically with sequence length, fundamentally limiting long-context inference. Existing block-granularity sparsification can reduce latency, but coarse blocks impose an intrinsic sparsity ceiling, making further improvements difficult even with carefully engineered designs. We present S2O, which performs early stopping for sparse attention via online permutation. Inspired by virtual-to-physical address mapping in memory systems, S2O revisits and factorizes FlashAttention execution, enabling inference to load non-contiguous tokens rather than a contiguous span in the original order. Motivated by fine-grained structures in attention heatmaps, we transform explicit permutation into an online, index-guided, discrete loading policy; with extremely lightweight preprocessing and index-remapping overhead, it concentrates importance on a small set of high-priority blocks. Building on this importance-guided online permutation for loading, S2O further introduces an early-stopping rule: computation proceeds from high to low importance, and once the current block score falls below a threshold, S2O terminates and skips the remaining low-contribution blocks. This increases effective sparsity and reduces computation under a controlled error budget.
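To make the early-stopping rule concrete, the following is a minimal NumPy sketch of the idea, not the S2O kernel: key/value blocks are scored by an importance proxy (here, the maximum query-key dot product per block, an assumption for illustration), visited in descending importance, and skipped once a block's relative contribution falls below a threshold. With the threshold set to zero, the loop visits every block and recovers exact softmax attention.

```python
import numpy as np

def sparse_attention_early_stop(q, K, V, block_size=4, rel_threshold=0.05):
    """Illustrative sketch of importance-ordered block attention with
    early stopping (single query vector, single head).

    - Blocks are scored by a cheap importance proxy (max q.k per block).
    - Blocks are processed from high to low importance ("online permutation").
    - Once exp(score - max_score) < rel_threshold, remaining blocks are skipped.
    """
    n = K.shape[0]
    blocks = [(s, min(s + block_size, n)) for s in range(0, n, block_size)]

    # Importance proxy per block: the block's largest query-key logit.
    scores = np.array([(q @ K[s:e].T).max() for s, e in blocks])
    order = np.argsort(-scores)          # high -> low importance
    smax = scores[order[0]]              # global max logit, for stable softmax

    num = np.zeros(V.shape[1])
    den = 0.0
    for b in order:
        if np.exp(scores[b] - smax) < rel_threshold:
            break                        # early stop: tail blocks contribute little
        s, e = blocks[b]
        w = np.exp(q @ K[s:e].T - smax)  # unnormalized attention weights
        num += w @ V[s:e]
        den += w.sum()
    return num / den
```

A usage check: with `rel_threshold=0.0` the result matches dense softmax attention exactly, while a positive threshold trades a small approximation error for skipped blocks.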
As a result, S2O substantially raises the practical sparsity ceiling. On Llama-3.1-8B under a 128K context, S2O reduces single-operator MSE by 3.82× at matched sparsity and reduces prefill compute density by 3.31× at matched MSE; meanwhile, it preserves end-to-end accuracy and achieves 7.51× attention and 3.81× end-to-end speedups.