DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference

πŸ“… 2024-03-30
πŸ“ˆ Citations: 2
✨ Influential: 1
πŸ€– AI Summary
Existing tree-structured LLM inference systems suffer from KV cache I/O redundancy and GPU load imbalance when processing shared-prefix tasks: the KV cache of shared prefixes is repeatedly transferred between GPU global and shared memory, and memory bandwidth is poorly utilized. This work proposes DeFT (Decoding with Flash Tree-attention), a hardware-friendly attention algorithm. First, it introduces KV-Guided Grouping, a prefix-aware mechanism that groups queries around shared KV cache segments, eliminating redundant KV loading. Second, it designs Flattened Tree KV Splitting, enabling fine-grained, low-redundancy, and load-balanced KV cache partitioning and scheduling. Experiments across three representative tree-structured tasks demonstrate that the method reduces KV cache I/O by 73–99% and intermediate-result I/O by up to 100%. End-to-end latency is accelerated by up to 2.23×, while attention latency improves by up to 3.59×.
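The I/O saving from KV-Guided Grouping can be illustrated with a toy count of KV tokens loaded. This is a hypothetical sketch, not DeFT's actual implementation: all names (`kv_loads_query_grouped`, `kv_loads_kv_grouped`, the tree layout) are illustrative. Each tree node owns a KV segment; queries at the leaves attend to every segment on their root path. Grouping by KV segment loads each shared-prefix segment once instead of once per query.

```python
# Illustrative sketch of KV-Guided Grouping (hypothetical names, not DeFT's API).

def kv_loads_query_grouped(tree, leaf_queries):
    """Baseline: each query independently loads all KV tokens on its path."""
    loads = 0
    for leaf in leaf_queries:
        node = leaf
        while node is not None:
            loads += tree[node]["kv_len"]
            node = tree[node]["parent"]
    return loads

def kv_loads_kv_grouped(tree, leaf_queries):
    """KV-guided: each needed KV segment is loaded once, paired with all
    queries of the leaves beneath it."""
    needed = set()
    for leaf in leaf_queries:
        node = leaf
        while node is not None:
            needed.add(node)
            node = tree[node]["parent"]
    return sum(tree[n]["kv_len"] for n in needed)

# Toy tree: a 1000-token shared prefix with four 10-token branches.
tree = {"root": {"parent": None, "kv_len": 1000}}
leaves = []
for i in range(4):
    name = f"b{i}"
    tree[name] = {"parent": "root", "kv_len": 10}
    leaves.append(name)

print(kv_loads_query_grouped(tree, leaves))  # 4 * (1000 + 10) = 4040
print(kv_loads_kv_grouped(tree, leaves))     # 1000 + 4 * 10  = 1040
```

With a long shared prefix, the KV-guided count approaches a ~74% reduction in this toy case, consistent with the 73–99% range reported above.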

πŸ“ Abstract
Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared prefixes of tokens, including few-shot prompting, multi-step reasoning, speculative decoding, etc. However, existing inference systems for tree-based applications are inefficient due to improper partitioning of queries and KV cache during attention calculation. This leads to two main issues: (1) a lack of memory access (IO) reuse for KV cache of shared prefixes, and (2) poor load balancing. As a result, there is redundant KV cache IO between GPU global memory and shared memory, along with low GPU utilization. To address these challenges, we propose DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of read/write operations of KV cache during attention calculation through KV-Guided Grouping, a method that avoids repeatedly loading KV cache of shared prefixes in attention computation. Additionally, we propose Flattened Tree KV Splitting, a mechanism that ensures even distribution of the KV cache across partitions with little computation redundancy, enhancing GPU utilization during attention computations. By reducing 73-99% KV cache IO and nearly 100% IO for partial results during attention calculation, DeFT achieves up to 2.23x speedup in end-to-end latency and up to 3.59x speedup in attention latency across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
Problem

Research questions and friction points this paper is trying to address.

Inefficient KV cache IO in tree-structured LLM inference.
Poor load balancing and GPU utilization in attention computation.
Redundant KV cache IO due to improper query partitioning.
Innovation

Methods, ideas, or system contributions that make the work stand out.

KV-Guided Grouping reduces redundant KV cache IO.
Flattened Tree KV Splitting ensures balanced GPU utilization.
DeFT achieves significant speedup in tree-based workloads.
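The load-balancing idea behind Flattened Tree KV Splitting can be sketched as follows. This is an illustrative simplification, not DeFT's actual scheduler: `flatten_and_split` and `per_segment_split` are hypothetical names. Tree KV segments of uneven length are flattened into one token stream and cut into near-equal chunks, so each partition (e.g. one per GPU thread block) gets roughly the same work; assigning whole segments per partition leaves the long shared prefix dominating one partition.

```python
# Illustrative sketch of Flattened Tree KV Splitting (hypothetical names).

def flatten_and_split(segment_lengths, num_partitions):
    """Flatten all KV segments into one stream, then split into
    near-equal chunks; returns per-partition token counts."""
    total = sum(segment_lengths)
    base, rem = divmod(total, num_partitions)
    return [base + (1 if i < rem else 0) for i in range(num_partitions)]

def per_segment_split(segment_lengths, num_partitions):
    """Naive baseline: assign whole segments round-robin; a long shared
    prefix then dominates a single partition."""
    parts = [0] * num_partitions
    for i, n in enumerate(segment_lengths):
        parts[i % num_partitions] += n
    return parts

segments = [1000, 10, 10, 10]  # shared prefix plus three short branches
print(per_segment_split(segments, 4))  # [1000, 10, 10, 10] -> imbalanced
print(flatten_and_split(segments, 4))  # [258, 258, 257, 257] -> balanced
```

Even splitting keeps all partitions busy, which is the GPU-utilization gain the second bullet refers to; the real algorithm additionally tracks which queries each chunk belongs to.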
πŸ‘₯ Authors
Jinwei Yao β€” Westlake University; University of Illinois Urbana-Champaign
Kaiqi Chen β€” Zhejiang University
Kexun Zhang β€” Carnegie Mellon University
Jiaxuan You β€” University of Illinois Urbana-Champaign
Binhang Yuan β€” Hong Kong University of Science and Technology
Zeke Wang β€” Zhejiang University
Tao Lin β€” Westlake University