MatryoshkaKV: Adaptive KV Compression via Trainable Orthogonal Projection

📅 2024-10-16
๐Ÿ›๏ธ International Conference on Learning Representations
📈 Citations: 2
✨ Influential: 0
🤖 AI Summary
To address the storage and memory bandwidth bottlenecks that the growing KV cache imposes on large language model (LLM) inference, this paper proposes the first learnable orthogonal projection method for compressing the cache along the feature dimension. Unlike mainstream approaches that compress only the first three axes of the cache tensors (layer, head, sequence), the method supports layer- and head-wise adaptive compression rates. It combines Matryoshka distillation training with a budget-driven adaptive search mechanism to overcome the performance collapse that plain PCA projection suffers at low compression rates. Initialized via PCA and refined through end-to-end fine-tuning, the method deploys plug-and-play on pretrained LLMs without architectural modification. Evaluated on LLaMA2-7B and Mistral-7B, it achieves an average 60% KV cache reduction (up to 75%) while retaining over 90% of original task performance, with high training data efficiency.
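The core idea above can be sketched in a few lines: each cached key/value vector is mapped into a lower-dimensional space by a matrix with orthonormal columns and reconstructed at attention time. This is a minimal illustration, not the paper's implementation; the shapes, the random orthogonal basis, and all variable names are hypothetical (the paper initializes the basis from PCA and then fine-tunes it with a distillation objective).

```python
import numpy as np

rng = np.random.default_rng(0)

d, r = 128, 48                        # head dim, retained rank (r < d)
K = rng.standard_normal((1000, d))    # stand-in for cached key vectors of one head

# Orthogonal basis; random here purely for illustration. The paper
# initializes from PCA and refines end-to-end.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
P = Q[:, :r]                          # d x r projection with orthonormal columns

K_small = K @ P                       # store only r features per token
K_approx = K_small @ P.T              # reconstruct when attention needs full dim

assert K_small.shape == (1000, r)
assert K_approx.shape == (1000, d)
```

Storing `K_small` instead of `K` cuts the feature axis from `d` to `r`, i.e. a compression rate of `1 - r/d` per head, independent of sequence length.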

๐Ÿ“ Abstract
KV cache has become a de facto technique for the inference of large language models (LLMs), where tensors of shape (layer number, head number, sequence length, feature dimension) are introduced to cache historical information for self-attention. As the size of the model and data grows, the KV cache can quickly become a bottleneck within the system in both storage and memory transfer. To address this, prior studies usually focus on the first three axes of the cache tensors for compression. This paper supplements them, focusing on the feature dimension axis, by utilizing low-rank projection matrices to transform the cache features into spaces with reduced dimensions. We begin by investigating the canonical orthogonal projection method for data compression through principal component analysis (PCA). We observe an issue with PCA projection: significant performance degradation at low compression rates. To bridge the gap, we propose to directly tune the orthogonal projection matrices with a distillation objective using an elaborate Matryoshka training strategy. After training, we adaptively search for the optimal compression rates for various layers and heads given varying compression budgets. Compared to previous works, our method can easily embrace pre-trained LLMs and holds a smooth tradeoff between performance and compression rate. We empirically observe the high data efficiency of our training procedure and find that our method can sustain over 90% performance with an average KV cache compression rate of 60% (and up to 75% in certain extreme scenarios) for popular LLMs like LLaMA2-7B-base and Mistral-7B-v0.3-base.
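The "Matryoshka" training strategy in the abstract refers to nesting: one orthogonal matrix whose leading-column prefixes act as projections for every compression rate at once, so a single trained matrix serves all budgets. The sketch below illustrates only this nesting property with a random orthogonal matrix; the dimensions, loop, and variable names are hypothetical and stand in for the paper's trained projections.

```python
import numpy as np

rng = np.random.default_rng(3)
d = 64
# One orthogonal matrix; its column prefixes form nested subspaces.
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))

K = rng.standard_normal((100, d))     # stand-in for cached key vectors
errs = []
for r in (16, 32, 48):                # nested budgets share the same matrix
    P = Q[:, :r]                      # prefix of columns = rank-r projector
    K_hat = (K @ P) @ P.T             # reconstruction under this budget
    errs.append(np.linalg.norm(K - K_hat))

# Because the subspaces are nested, reconstruction error can only
# shrink as the budget r grows.
assert errs[0] >= errs[1] >= errs[2]
```

This nesting is what lets the post-training search pick a different rate per layer and head without retraining: every rate is a prefix of the same matrix.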
Problem

Research questions and friction points this paper is trying to address.

Compress the KV cache in LLMs to relieve storage and memory transfer bottlenecks
Compress the feature dimension via trainable orthogonal projection matrices
Retain over 90% of task performance at 60-75% KV cache compression
Innovation

Methods, ideas, or system contributions that make the work stand out.

Uses trainable orthogonal projections for KV cache compression
Adaptively searches optimal compression rates per layer and head
Maintains over 90% performance with 60-75% KV cache compression
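The budget-driven search allocates different compression rates across heads. One simple way such a search could work is a greedy allocation that spends the feature budget where it buys the largest error reduction; this is a hypothetical sketch under that assumption, not the paper's actual search criterion, and the exponential spectra are synthetic stand-ins for per-head PCA eigenvalues.

```python
import numpy as np

rng = np.random.default_rng(2)
n_heads, d = 8, 64
# Synthetic per-head PCA spectra (descending eigenvalues of key covariance).
spectra = [np.sort(rng.exponential(1.0, d))[::-1] for _ in range(n_heads)]

budget = int(0.4 * n_heads * d)       # keep 40% of total features overall
ranks = [0] * n_heads
for _ in range(budget):
    # Grant the next retained dimension to the head whose next
    # eigenvalue (marginal reconstruction gain) is largest.
    gains = [s[k] if k < d else -np.inf for s, k in zip(spectra, ranks)]
    ranks[int(np.argmax(gains))] += 1

assert sum(ranks) == budget
```

Heads with flatter spectra end up with larger ranks, which matches the paper's observation that a uniform rate across layers and heads is suboptimal.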
Bokai Lin
Qing Yuan Research Institute, SEIEE, Shanghai Jiao Tong University
Zihao Zeng
Qing Yuan Research Institute, SEIEE, Shanghai Jiao Tong University
Zipeng Xiao
Shanghai Jiao Tong University
Deep learning
Siqi Kou
Shanghai Jiao Tong University
Machine Learning
Tianqi Hou
Theory Lab, Central Research Institute, 2012 Labs, Huawei Technologies Co., Ltd.
statistical physics, machine learning, high-dimensional statistics, computational neuroscience
Xiaofeng Gao
Shanghai Jiao Tong University
Data Engineering, Network Optimization
Hao Zhang
University of California, San Diego
Zhijie Deng
Qing Yuan Research Institute, SEIEE, Shanghai Jiao Tong University