🤖 AI Summary
Existing static ranking methods for KV cache compression in long-context LLM inference rely on heuristics or raw attention scores and overlook the spatiotemporal dynamics of attention, leading to inaccurate critical token identification and performance degradation. This work proposes the first learning-based framework for critical token identification: (1) a lightweight convolutional network models the temporal evolution of attention scores; (2) a cross-token KV cache prefetching mechanism hides the prediction overhead behind decode compute; and (3) the predicted scores drive dynamic, fine-grained KV cache compression. Evaluated on standard long-context benchmarks, the method achieves a 16× KV cache compression ratio while retaining over 98% of the original model's accuracy, significantly outperforming current state-of-the-art approaches. The framework bridges the gap between attention dynamics modeling and efficient memory management, enabling scalable long-context inference without compromising model fidelity.
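The core idea above — predict each cached token's next-step attention from a short temporal window of its past scores, then keep only the predicted-critical tokens — can be sketched as follows. This is an illustrative toy, not the paper's trained model: the recency-weighted kernel stands in for the learned lightweight convolutional network, and the window size, token count, and 16× keep ratio are assumed for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, window = 64, 4

# history[t, i]: attention score that decode step t paid to cached token i,
# normalized per step like softmax attention rows.
history = rng.random((window, n_tokens))
history /= history.sum(axis=1, keepdims=True)

# Stand-in for the learned temporal conv: a fixed recency-weighted kernel
# (in the actual method these weights would be trained, not hand-picked).
kernel = np.array([0.1, 0.2, 0.3, 0.4])   # sums to 1, favors recent steps
predicted = kernel @ history              # per-token next-step score, shape (n_tokens,)

# Compress 16x: retain only the top 1/16 of tokens by predicted criticality.
keep = max(1, n_tokens // 16)
critical = np.argsort(predicted)[-keep:]
print(f"retaining {keep}/{n_tokens} tokens:", np.sort(critical))
```

Because the kernel and each history row both sum to 1, the predicted scores remain a valid distribution over cached tokens, so the top-k selection is directly comparable to ranking raw attention scores.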
📝 Abstract
With the development of large language models (LLMs), efficient inference through Key-Value (KV) cache compression has attracted considerable attention, especially for long-context generation. To compress the KV cache, recent methods identify critical KV tokens through heuristic ranking with attention scores. However, these methods often struggle to accurately determine critical tokens because they neglect the *temporal patterns* in attention scores, resulting in a noticeable degradation in LLM performance. To address this challenge, we propose AttentionPredictor, the first learning-based critical token identification approach. Specifically, AttentionPredictor learns a lightweight convolutional model to capture spatiotemporal patterns and predict the next-token attention score. An appealing feature of AttentionPredictor is that it accurately predicts the attention score while consuming negligible memory. Moreover, we propose a cross-token critical cache prefetching framework that hides the token estimation time overhead to accelerate the decoding stage. By retaining most of the attention information, AttentionPredictor achieves 16× KV cache compression with comparable LLM performance, significantly outperforming the state-of-the-art.
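The cross-token prefetching idea — hide the critical-token estimation latency by overlapping it with the current decode step — can be sketched with a background worker. This is a hypothetical illustration, not the paper's implementation: the sleep durations, function names, and placeholder token indices are all assumptions standing in for real GPU compute and KV gathering.

```python
import threading
import time

def decode_step(t):
    """Stand-in for the decode-step forward pass (attention + MLP)."""
    time.sleep(0.01)
    return f"token_{t}"

def prefetch_critical(t, staged):
    """Stand-in for predicting step t's critical tokens and staging their KV
    entries; intended to finish within the decode step's latency."""
    time.sleep(0.005)
    staged[t] = [0, 7, 42]   # placeholder critical-token indices

staged, tokens = {}, []
for t in range(3):
    # Launch the prediction for step t+1 ...
    worker = threading.Thread(target=prefetch_critical, args=(t + 1, staged))
    worker.start()
    # ... while step t's compute runs concurrently on the "GPU".
    tokens.append(decode_step(t))
    worker.join()  # prefetch completed under the decode latency, so join is free
print(tokens, sorted(staged))
```

Since the estimation (5 ms here) is shorter than the decode step (10 ms here), the join returns immediately and the prediction cost is fully amortized, which is the point of the cross-token design.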