SeerAttention: Learning Intrinsic Sparse Attention in Your LLMs

📅 2024-10-17
🏛️ arXiv.org
📈 Citations: 3
Influential: 0
📄 PDF
🤖 AI Summary
To address the inefficiency of large language models (LLMs) in long-context inference caused by the quadratic complexity of standard attention, this work proposes a learnable block-sparse attention mechanism. Methodologically, it introduces the first end-to-end framework that learns block-sparse attention patterns dynamically from the LLM itself, eliminating hand-crafted sparsity templates. A lightweight, MoE-inspired gating module pools the query (Q) and key (K) tensors along the sequence dimension, applies learnable linear projections, and multiplies the results to produce block-level gating scores. The approach is paired with a block-sparse FlashAttention kernel and a lightweight self-distillation fine-tuning paradigm that trains only the gate parameters. Experiments demonstrate that SeerAttention significantly reduces prefill latency and achieves better accuracy and GPU inference throughput than prior sparse attention methods.

📝 Abstract
Attention is the cornerstone of modern Large Language Models (LLMs). Yet its quadratic complexity hinders efficiency and scalability, especially for long-context processing. A promising approach is to leverage sparsity in attention. However, existing sparsity-based solutions predominantly rely on predefined patterns or heuristics at the attention head level, struggling to adapt dynamically to different contexts efficiently. We propose SeerAttention, a simple yet effective attention mechanism that directly learns the block-level attention sparsity from the LLM itself. Inspired by the gating mechanism in Mixture of Experts (MoE), SeerAttention augments the conventional attention with a learnable gate that selectively activates important blocks within the attention map. Specifically, the gate first pools the query (Q) and key (K) tensors along the sequence dimension and processes them through learnable linear layers. The resulting matrices are then multiplied together to produce the gating scores, which are used to predict block-level attention sparsity. Combined with our block-sparse FlashAttention kernel, SeerAttention can achieve significant speedup on GPUs. When applied to pre-trained LLMs, SeerAttention only requires training the gate parameters in a lightweight self-distillation manner, allowing rapid convergence. Our evaluation results demonstrate that SeerAttention achieves better model accuracy and lower latency for long-context pre-filling compared to prior methods.
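The gating mechanism described in the abstract can be sketched as follows: pool Q and K along the sequence dimension in blocks, project each pooled tensor with a learnable linear layer, and multiply the results to obtain block-level gating scores that select which attention blocks to compute. This is a minimal numpy illustration, not the paper's implementation; the mean pooling, gate dimension, and top-k block selection are assumptions (the paper's kernel, training, and exact pooling may differ).

```python
import numpy as np

def block_gate_scores(Q, K, Wq, Wk, block_size):
    """Sketch of a SeerAttention-style gate: pool Q/K per block along
    the sequence dimension, project with (learnable) linear weights,
    and multiply to get block-level gating scores."""
    def pool(X):
        # Mean-pool each contiguous block of `block_size` rows
        # (mean pooling is an assumption for this sketch).
        n = X.shape[0] // block_size
        return X[: n * block_size].reshape(n, block_size, -1).mean(axis=1)

    Qp = pool(Q) @ Wq          # (num_blocks, d_gate)
    Kp = pool(K) @ Wk          # (num_blocks, d_gate)
    return Qp @ Kp.T           # (num_blocks, num_blocks) gating scores

def block_mask(scores, top_k):
    """Keep the top-k highest-scoring key blocks per query block
    (causal masking omitted for brevity)."""
    idx = np.argsort(scores, axis=1)[:, -top_k:]
    mask = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=1)
    return mask

# Usage: 256-token sequence, head dim 64, 64-token blocks -> 4x4 block map.
rng = np.random.default_rng(0)
Q = rng.standard_normal((256, 64))
K = rng.standard_normal((256, 64))
Wq = rng.standard_normal((64, 16))  # random stand-ins for learned weights
Wk = rng.standard_normal((64, 16))
scores = block_gate_scores(Q, K, Wq, Wk, block_size=64)
mask = block_mask(scores, top_k=2)  # each query block attends to 2 key blocks
```

In the paper, the resulting block mask drives a block-sparse FlashAttention kernel so that only the selected blocks of the attention map are ever computed; here the mask is just produced, not consumed.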
Problem

Research questions and friction points this paper is trying to address.

Quadratic complexity of standard attention in long-context LLM inference
Predefined, head-level sparsity patterns that cannot adapt to different contexts
Efficiency and scalability bottlenecks in long-context processing
Innovation

Methods, ideas, or system contributions that make the work stand out.

Learns block-level attention sparsity
Uses learnable gating mechanism
Integrates block-sparse FlashAttention kernel