Tree Attention: Topology-aware Decoding for Long-Context Attention on GPU Clusters

πŸ“… 2024-08-07
πŸ›οΈ arXiv.org
πŸ“ˆ Citations: 2
✨ Influential: 1
πŸ“„ PDF
πŸ€– AI Summary
To address high inter-device communication overhead, excessive peak GPU memory consumption, and poor scalability in long-context LLM inference on GPU clusters, this paper proposes Tree Attention, a tree-reduction-based parallel attention mechanism. It is the first to model the reduction along the sequence axis as a topology-aware tree, enabling communication-efficient, memory-friendly collaborative decoding across GPUs. The method combines a tree-reduction algorithm, fine-grained multi-GPU scheduling, and hardware-aware adaptation to heterogeneous accelerators (H100, MI300x, RTX 4090), along with a lightweight synchronization protocol. Theoretical analysis shows its communication complexity is asymptotically lower than that of Ring Attention. Experiments on Llama 3.1-8B demonstrate up to 8× higher decoding throughput (4× real-world speedup), significantly reduced inter-GPU communication volume, and a 50% reduction in peak GPU memory usage.

πŸ“ Abstract
Our formulation reveals that the reduction across the sequence axis can be efficiently computed in parallel through a tree reduction. Our algorithm, called Tree Attention, for parallelizing exact attention computation across multiple GPUs enables cross-device decoding to be performed asymptotically faster (up to 8x faster in our experiments) than state-of-the-art approaches such as Ring Attention, while also requiring significantly less communication volume and incurring 2x less peak memory. We demonstrate that Tree Attention speeds up decoding up to 4x on Llama 3.1-8B and can be applied to a variety of hardware and networking setups such as H100 DGX nodes, AMD MI300x nodes, and PCIe connected NVIDIA RTX 4090s. Our code is publicly available here: https://github.com/Zyphra/tree_attention
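The key observation in the abstract is that softmax attention can be decomposed into per-chunk partial results whose combination is associative, so chunks held on different devices can be merged pairwise in a log-depth tree rather than sequentially around a ring. The following is a minimal single-process NumPy sketch of that combine rule (the running-max / exp-sum / unnormalized-output triple); function and variable names are illustrative and not taken from the paper's released code.

```python
import numpy as np

def partial_attention(q, k, v):
    # Local statistics for one sequence chunk: running max, exp-sum,
    # and the unnormalized weighted sum of values.
    scores = k @ q                        # (chunk_len,)
    m = scores.max()
    w = np.exp(scores - m)
    return m, w.sum(), w @ v              # (max, sum, unnormalized output)

def combine(a, b):
    # Associative merge of two partial results. Associativity is what
    # lets a cluster reduce chunks pairwise in a log-depth tree
    # (e.g. via an allreduce) instead of a linear ring pass.
    m = max(a[0], b[0])
    s = np.exp(a[0] - m) * a[1] + np.exp(b[0] - m) * b[1]
    o = np.exp(a[0] - m) * a[2] + np.exp(b[0] - m) * b[2]
    return m, s, o

rng = np.random.default_rng(0)
d, n, chunks = 16, 64, 4
q = rng.normal(size=d)
k = rng.normal(size=(n, d))
v = rng.normal(size=(n, d))

# Each "device" holds one sequence chunk; tree-reduce the partials.
parts = [partial_attention(q, kc, vc)
         for kc, vc in zip(np.split(k, chunks), np.split(v, chunks))]
while len(parts) > 1:
    parts = [combine(parts[i], parts[i + 1]) for i in range(0, len(parts), 2)]
m, s, o = parts[0]
tree_out = o / s

# Reference: ordinary softmax attention over the full sequence.
w = np.exp(k @ q - (k @ q).max())
ref_out = (w / w.sum()) @ v
print(np.allclose(tree_out, ref_out))
```

Because each merge is exact (it only re-bases the shared max), the tree-reduced result matches full softmax attention to floating-point precision; in the distributed setting the triples are what travel between GPUs, which is where the communication savings over ring-style passing come from.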
Problem

Research questions and friction points this paper is trying to address.

Parallelizing exact attention computation across multiple GPUs
Reducing inter-GPU communication volume and peak memory in clusters
Speeding up long-context decoding across diverse hardware and networking setups
Innovation

Methods, ideas, or system contributions that make the work stand out.

Tree Attention: exact attention parallelized via a topology-aware tree reduction
Significantly lower communication volume and 2x lower peak memory than Ring Attention
Up to 4x faster decoding on Llama 3.1-8B across H100 DGX, AMD MI300x, and PCIe RTX 4090 setups
πŸ”Ž Similar Papers
No similar papers found.