🤖 AI Summary
Standard attention incurs quadratic computational and memory complexity, posing a severe bottleneck for long-context reasoning. To address this, we propose Block-Sparse FlashAttention (BSFA): a training-free, block-level sparsification method that computes query-key similarities exactly, then prunes attention blocks by comparing each block's maximum similarity score against a calibrated threshold, retaining only the top-k most significant value blocks. BSFA introduces hierarchical head-wise threshold calibration and custom CUDA kernels that preserve fine-grained attention patterns. Evaluated on Llama-3.1-8B, BSFA achieves up to a 1.10× inference speedup on reasoning benchmarks and up to 1.24× acceleration on retrieval tasks, outperforming existing sparse attention methods. Crucially, it maintains over 99% of the original model's accuracy, matching or even exceeding baseline performance on certain tasks, while serving as a drop-in replacement for FlashAttention.
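The head-wise threshold calibration mentioned above can be sketched as follows. This is a hypothetical simplification (the function name, `keep_ratio` parameter, and quantile-based rule are assumptions, not the paper's exact procedure): for each head, collect per-block maximum attention scores on a small calibration set, then pick the threshold that keeps roughly the desired fraction of blocks.

```python
# Hypothetical sketch of one-time, per-head threshold calibration.
# The quantile rule and keep_ratio parameter are illustrative assumptions.
import numpy as np

def calibrate_threshold(block_max_scores, keep_ratio=0.5):
    """Given the per-block maximum attention scores observed for one head
    on a small calibration dataset, return a threshold such that roughly
    `keep_ratio` of blocks would survive pruning."""
    scores = np.asarray(block_max_scores, dtype=float)
    # Keep the top `keep_ratio` fraction of blocks: threshold at the
    # (1 - keep_ratio) quantile of observed block maxima.
    return np.quantile(scores, 1.0 - keep_ratio)
```

Because score distributions differ across layers and heads, this calibration would be run once per (layer, head) pair and the resulting thresholds stored for use at inference time.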
📝 Abstract
Modern large language models increasingly require long contexts for reasoning and multi-document tasks, but attention's quadratic complexity creates a severe computational bottleneck. We present Block-Sparse FlashAttention (BSFA), a drop-in replacement for FlashAttention that accelerates long-context inference while preserving model quality. Unlike methods that predict importance before computing scores, BSFA computes exact query-key similarities and selects the top-k most important value blocks for each query. By comparing per-block maximum scores against calibrated thresholds, it skips approximately 50% of the computation and memory transfers associated with pruned blocks. The approach is training-free, requiring only a one-time threshold calibration on a small dataset to learn the per-layer, per-head attention score distributions. We provide a CUDA kernel implementation that can be used in place of FlashAttention. On Llama-3.1-8B, BSFA achieves up to a 1.10x speedup on real-world reasoning benchmarks and up to 1.24x on needle-in-a-haystack retrieval tasks while maintaining over 99% of baseline accuracy; certain configurations even improve accuracy by focusing attention on the most relevant content, substantially outperforming existing sparse attention methods. The implementation is available at https://github.com/Danielohayon/Block-Sparse-Flash-Attention.
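The pruning rule described above can be illustrated with a simplified, non-causal, single-head NumPy sketch. All names here (`block_sparse_attention`, `block_size`, `threshold`) are illustrative assumptions; the actual implementation is a fused CUDA kernel that skips the value loads and softmax work of pruned blocks inside the FlashAttention loop.

```python
# Simplified single-head sketch of BSFA-style block pruning (non-causal,
# no batching). Names and the always-keep-diagonal rule are illustrative.
import numpy as np

def block_sparse_attention(Q, K, V, block_size, threshold):
    n, d = Q.shape
    out = np.zeros_like(Q)
    n_blocks = n // block_size
    for qb in range(n_blocks):
        q = Q[qb * block_size:(qb + 1) * block_size]      # query block (B, d)
        score_cols, kept_v = [], []
        for kb in range(n_blocks):
            k = K[kb * block_size:(kb + 1) * block_size]
            s = q @ k.T / np.sqrt(d)                      # exact QK^T scores
            # Prune: skip key/value blocks whose maximum score falls below
            # the calibrated threshold. The diagonal block is always kept
            # here so at least one block survives.
            if kb != qb and s.max() < threshold:
                continue
            score_cols.append(s)
            kept_v.append(V[kb * block_size:(kb + 1) * block_size])
        s = np.concatenate(score_cols, axis=1)
        v = np.concatenate(kept_v, axis=0)
        p = np.exp(s - s.max(axis=1, keepdims=True))      # stable softmax
        p /= p.sum(axis=1, keepdims=True)
        out[qb * block_size:(qb + 1) * block_size] = p @ v
    return out
```

With `threshold = -inf` every block is retained and the sketch reduces to dense attention; raising the threshold trades accuracy for skipped block computation, which is where the reported speedups come from.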