AI Summary
Existing inference systems for tree-structured LLM workloads suffer from redundant KV cache I/O and GPU load imbalance when processing tasks with shared prefixes, which leads to repeated transfers of the same KV cache between GPU global memory and shared memory and to low GPU utilization. This work proposes DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm. First, it introduces KV-Guided Grouping, a prefix-aware way of partitioning the KV cache that avoids repeatedly loading the KV cache of shared prefixes. Second, it designs Flattened Tree KV Splitting, which enables fine-grained, low-redundancy, load-balanced partitioning and scheduling of the KV cache. Experiments on three representative tree-structured workloads show that the method reduces KV cache I/O by 73-99% and I/O for intermediate (partial) results by nearly 100%. It achieves speedups of up to 2.23x in end-to-end latency and up to 3.59x in attention latency.
Abstract
Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared prefixes of tokens, such as few-shot prompting, multi-step reasoning, and speculative decoding. However, existing inference systems for tree-based applications are inefficient due to improper partitioning of queries and KV cache during attention calculation. This leads to two main issues: (1) a lack of memory access (IO) reuse for the KV cache of shared prefixes, and (2) poor load balancing. As a result, there is redundant KV cache IO between GPU global memory and shared memory, along with low GPU utilization. To address these challenges, we propose DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of read/write operations on the KV cache during attention calculation through KV-Guided Grouping, a method that avoids repeatedly loading the KV cache of shared prefixes in attention computation. Additionally, we propose Flattened Tree KV Splitting, a mechanism that ensures even distribution of the KV cache across partitions with little computation redundancy, enhancing GPU utilization during attention computation. By reducing KV cache IO by 73-99% and IO for partial results by nearly 100% during attention calculation, DeFT achieves speedups of up to 2.23x in end-to-end latency and up to 3.59x in attention latency across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
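To make the IO argument concrete, the following is a minimal Python sketch, not the DeFT implementation. It counts how many cached tokens are loaded for one decoding step on a toy prefix tree under two schemes: every query loading its own full root-to-leaf KV path, versus each node's KV cache being loaded once and shared by all queries beneath it (our reading of KV-Guided Grouping). The `split_flattened` helper is only a crude stand-in for Flattened Tree KV Splitting: it caps partition size with a fixed chunk length and does not reproduce the paper's balancing strategy. The tree shape, token counts, and all function names are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """One node of a toy prefix tree; its KV cache is shared by all descendants."""
    name: str
    kv_len: int                       # number of cached tokens stored in this node
    parent: Optional["Node"] = None
    children: list = field(default_factory=list)

    def add(self, name, kv_len):
        child = Node(name, kv_len, parent=self)
        self.children.append(child)
        return child


def all_nodes(node):
    """Pre-order traversal of the tree."""
    out = [node]
    for child in node.children:
        out.extend(all_nodes(child))
    return out


def leaves(node):
    """Leaf nodes; each leaf stands for one active decoding query."""
    return [n for n in all_nodes(node) if not n.children]


def path_to_root(leaf):
    path, n = [], leaf
    while n is not None:
        path.append(n)
        n = n.parent
    return path


def kv_io_per_query(root):
    """Assumed baseline: every query loads its entire root-to-leaf KV path."""
    return sum(n.kv_len for leaf in leaves(root) for n in path_to_root(leaf))


def kv_io_kv_guided(root):
    """Assumed grouped scheme: each node's KV is loaded once for all its queries."""
    return sum(n.kv_len for n in all_nodes(root))


def split_flattened(root, chunk):
    """Cut every node's KV into (name, start, end) pieces of at most `chunk` tokens."""
    pieces = []
    for n in all_nodes(root):
        for start in range(0, n.kv_len, chunk):
            pieces.append((n.name, start, min(start + chunk, n.kv_len)))
    return pieces


# Hypothetical workload: a long shared prompt with three short branches.
root = Node("prompt", 1000)
a = root.add("a", 40)
b = root.add("b", 10)
a.add("a1", 5)
a.add("a2", 5)

baseline = kv_io_per_query(root)      # 1045 + 1045 + 1010 tokens
grouped = kv_io_kv_guided(root)       # 1000 + 40 + 10 + 5 + 5 tokens
pieces = split_flattened(root, chunk=256)
print(baseline, grouped, len(pieces))
```

On this toy tree the per-query scheme loads 3100 cached tokens and the grouped scheme loads 1060, because the 1000-token prompt is read once instead of three times. These counts illustrate the mechanism only and are unrelated to the 73-99% reduction reported above.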