🤖 AI Summary
To address KV cache redundancy in Transformer inference, this work proposes a token-level attention role identification method that classifies each token as global, local, or sliding-window and dynamically determines its KV cache retention duration. Crucially, token role assignment is formulated as a learnable discrete search space, enabling, for the first time, end-to-end joint optimization of cache structure and model weights. Leveraging a one-shot neural architecture search framework, we introduce differentiable attention masks to jointly learn optimal token-type assignments during training or fine-tuning—eliminating hand-crafted eviction heuristics. Evaluated on both from-scratch training and large language model fine-tuning, our approach reduces KV cache memory footprint by 40% on average while significantly lowering inference latency, with no degradation in model accuracy.
📝 Abstract
We present Neural Attention Search (NAtS), a framework that automatically evaluates the importance of each token within a sequence and determines if the corresponding token can be dropped after several steps. This approach can efficiently reduce the KV cache sizes required by transformer-based models during inference and thus reduce inference costs. In this paper, we design a search space that contains three token types: (i) Global Tokens will be preserved and queried by all the following tokens. (ii) Local Tokens survive until the next global token appears. (iii) Sliding Window Tokens have an impact on the inference of a fixed size of the next following tokens. Similar to the One-Shot Neural Architecture Search approach, this token-type information can be learned jointly with the architecture weights via a learnable attention mask. Experiments on both training a new transformer from scratch and fine-tuning existing large language models show that NAtS can efficiently reduce the KV cache size required for the models while maintaining the models' performance.