S2-Attention: Hardware-Aware Context Sharding Among Attention Heads

๐Ÿ“… 2024-07-25
๐Ÿ“ˆ Citations: 0
โœจ Influential: 0
๐Ÿ“„ PDF
๐Ÿค– AI Summary
To address the gap between sparse attention's theoretical FLOP savings and its actual wall-clock performance in large language models, this paper proposes S2-Attention, a hardware-aware Triton kernel library for sparse attention. Methodologically, it introduces (1) sparsity patterns customizable per attention head and per context range; (2) a heterogeneous context sharding mechanism, in which each head attends to a different subset of tokens while the heads collectively cover the full context; and (3) hybrid architectures that combine sparse and dense attention. S2-Attention provides hardware-aware optimizations in the spirit of FlashAttention-2, ships with easy-to-customize APIs for Megatron and vLLM, and optimizes memory I/O and parallel scheduling. Experiments show training speedups of up to 25.3× over the FlashAttention-2 baseline, perfect retrieval at 128K context length, and, at inference, a 4.5× speedup for 7B models while matching full-attention performance on downstream tasks.

๐Ÿ“ Abstract
Sparse attention, which selectively attends to a subset of tokens in the context, was supposed to be efficient. However, its theoretical reduction in FLOPs has rarely translated into wall-clock speed-up over its dense attention counterparts due to the lack of hardware-aware optimizations like FlashAttention. Meanwhile, it remains unclear whether sparse attention can maintain the model's quality at the scale of today's large language models (LLMs), and how. This paper presents Sparsely-Sharded (S2) Attention, a Triton library that provides kernel optimization for sparse attention customizable at both per-head and per-context-range levels. S2-Attention enables the exploration of novel and high-performance sparse attention techniques, which we demonstrate through extensive ablations across a wide range of sparse attention designs at various model scales. From these insights, we present several basic guidelines for designing sparse attention that achieves not only practical efficiency improvements, but also strong downstream performance. To achieve high parallelization and optimized memory I/O, sparse attention should shard the context heterogeneously across attention heads, where each head attends to a different subset of tokens while collectively covering the full context. Meanwhile, we find hybrid architectures combining sparse and dense attention particularly beneficial in practice. S2-Attention achieves wall-clock speedups of 8.79X, 15.87X, and 25.3X compared to the strong FlashAttention-2 baseline, with strong downstream performance on par with full attention and perfect retrieval performance at a 128k context length. At inference, for 7B models, our model, with the help of our S2-Attention kernel, achieves a 4.5X speed-up compared to dense counterparts. S2-Attention is released with easy-to-customize APIs for direct usage in Megatron and vLLM.
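The abstract's central design guideline, heterogeneous context sharding, can be illustrated with a toy mask construction: each head gets a causal local window plus its own strided shard of earlier tokens, so every head is sparse on its own, yet the union over heads covers the full causal context. This is a minimal NumPy sketch under assumed details (the window size, the strided shard assignment, and the function name are all illustrative), not the S2-Attention Triton kernel or its API.

```python
import numpy as np

def shard_masks(seq_len, n_heads, local_window=4):
    """Toy per-head attention masks: head h sees a causal local window
    plus the keys at positions h, h + n_heads, h + 2*n_heads, ...
    Illustrative sketch only -- not the actual S2-Attention kernel."""
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    masks = np.zeros((n_heads, seq_len, seq_len), dtype=bool)
    for h in range(n_heads):
        # causal local window around each query position
        local = np.zeros((seq_len, seq_len), dtype=bool)
        for q in range(seq_len):
            local[q, max(0, q - local_window + 1):q + 1] = True
        # head-specific strided shard of the key positions
        shard = np.zeros((seq_len, seq_len), dtype=bool)
        shard[:, h::n_heads] = True
        masks[h] = (local | shard) & causal
    return masks

masks = shard_masks(seq_len=16, n_heads=4)
union = masks.any(axis=0)
# Each head is sparse, but together the heads cover the full causal mask.
print("head 0 entries:", masks[0].sum(), "union entries:", union.sum())
```

Because every key position belongs to some head's strided shard, the union of the per-head masks equals the dense causal mask, which is the "collectively covering the full context" property the abstract describes; per head, the attended set is much smaller, which is where the FLOP and memory-I/O savings come from.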
Problem

Research questions and friction points this paper is trying to address.

Sparse Attention Mechanism
Efficiency Optimization
Stability Enhancement
Innovation

Methods, ideas, or system contributions that make the work stand out.

Sparsely-Sharded Attention
Sparse Attention Efficiency
S2-Attention Integration
๐Ÿ”Ž Similar Papers
No similar papers found.