AI Summary
To address storage and memory bandwidth bottlenecks in large language model (LLM) inference caused by KV cache expansion along the feature dimension, this paper proposes the first learnable orthogonal projection method targeting feature-dimensional compression. Unlike mainstream approaches that compress only along the first three axes of the cache tensor (layer, head, sequence), this method enables layer- and head-wise adaptive compression ratios. It integrates Matryoshka distillation with a budget-driven adaptive search mechanism to overcome the performance collapse that PCA suffers at low compression rates. Initialized via PCA and refined through end-to-end fine-tuning, the method supports plug-and-play deployment on pretrained LLMs without architectural modification. Evaluated on LLaMA2-7B and Mistral-7B, it achieves an average 60% KV cache reduction (up to 75%) while retaining over 90% of original task performance, with high training data efficiency.
Abstract
The KV cache has become a de facto technique for the inference of large language models (LLMs), where tensors of shape (layer number, head number, sequence length, feature dimension) are introduced to cache historical information for self-attention. As the size of the model and data grows, the KV cache can quickly become a system bottleneck in both storage and memory transfer. To address this, prior studies usually focus on compressing the first three axes of the cache tensors. This paper complements them by focusing on the feature dimension axis, utilizing low-rank projection matrices to transform the cache features into spaces of reduced dimension. We begin by investigating the canonical orthogonal projection method for data compression via principal component analysis (PCA), and observe that PCA projection suffers significant performance degradation at low compression rates. To bridge this gap, we propose to directly tune the orthogonal projection matrices with a distillation objective using an elaborate Matryoshka training strategy. After training, we adaptively search for the optimal compression rates for the various layers and heads under different compression budgets. Compared to previous works, our method can be readily applied to pre-trained LLMs and offers a smooth tradeoff between performance and compression rate. We empirically observe high data efficiency in our training procedure and find that our method can sustain over 90% performance at an average KV cache compression rate of 60% (and up to 75% in certain extreme scenarios) for popular LLMs like LLaMA2-7B-base and Mistral-7B-v0.3-base.
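To make the feature-dimension idea concrete, the following is a minimal sketch (not the authors' code) of the PCA baseline described above: cached per-head features of dimension d are projected onto an orthogonal rank-r basis and reconstructed on read. All names, the head dimension, and the target rank are illustrative assumptions.

```python
import numpy as np

def pca_projection(feats, rank):
    # feats: (tokens, d) matrix of cached key/value features for one head
    mean = feats.mean(axis=0)
    centered = feats - mean
    # SVD of the centered data yields the principal directions (rows of vt)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    P = vt[:rank].T  # (d, rank) orthogonal projection matrix
    return P, mean

def compress(feats, P, mean):
    # Store only the (tokens, rank) coordinates in the low-rank subspace
    return (feats - mean) @ P

def decompress(comp, P, mean):
    # Approximate reconstruction of the original (tokens, d) features
    return comp @ P.T + mean

# Hypothetical sizes: head dim 128 compressed to rank 32 (75% reduction)
rng = np.random.default_rng(0)
d, rank = 128, 32
feats = rng.normal(size=(512, d)) @ rng.normal(size=(d, d)) * 0.1
P, mu = pca_projection(feats, rank)
comp = compress(feats, P, mu)
recon = decompress(comp, P, mu)
err = np.linalg.norm(feats - recon) / np.linalg.norm(feats)
print(f"compression {1 - rank / d:.0%}, relative reconstruction error {err:.3f}")
```

The paper's contribution departs from this baseline: rather than fixing P via PCA, the projection matrices are initialized with PCA and then fine-tuned end-to-end with a Matryoshka distillation objective, after which per-layer, per-head ranks are searched under a global budget.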