🤖 AI Summary
Long-sequence LLM inference is bottlenecked in both memory and compute by the ever-growing KV cache. This paper proposes TRIM-KV, a lightweight, learnable method for token-level cache management. Its core is a differentiable retention gate whose per-token score decays over time. The gates are trained end-to-end by distillation from a frozen LLM to estimate each token's long-horizon importance, and low-value tokens are evicted dynamically during generation. Unlike heuristic or rule-based approaches, TRIM-KV needs no hand-designed eviction rules. It nonetheless rediscovers familiar strategies on its own, including sink tokens, sliding-window behavior, and gist compression, which keeps the cache efficient and its decisions interpretable. TRIM-KV is evaluated on mathematical reasoning, procedural generation, long-context understanding, and conversational memory. Across these benchmarks it outperforms strong eviction and learnable retrieval baselines, especially under tight memory budgets. In some settings it even surpasses the full-cache model, which suggests that selective retention acts as a form of regularization.
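The following minimal Python sketch shows a single-head cache that is driven by decaying retention scores. It assumes, for illustration only, two things. First, the gate emits a value `beta_i` in (0, 1] when token i is created. Second, the token's score at step t is `beta_i ** (t - i)`. The class `RetentionCache` and its methods are hypothetical. The learned gate is replaced by hand-picked values, and the keys and values themselves are omitted.

```python
import math
from dataclasses import dataclass, field


@dataclass
class RetentionCache:
    """Hypothetical single-head KV cache that evicts by decaying retention score.

    Assumption: token i receives a gate value beta_i in (0, 1] when it is
    created, and its retention score at step t is beta_i ** (t - i).
    Keys/values are omitted; only the eviction bookkeeping is shown.
    """

    budget: int
    positions: list[int] = field(default_factory=list)  # creation step of each cached token
    betas: list[float] = field(default_factory=list)    # gate value of each cached token

    def log_score(self, k: int, t: int) -> float:
        # log(beta_i ** (t - i)) = (t - i) * log(beta_i); log space avoids underflow
        return (t - self.positions[k]) * math.log(self.betas[k])

    def append(self, t: int, beta: float) -> None:
        """Add the token created at step t, then evict until within budget."""
        if not 0.0 < beta <= 1.0:
            raise ValueError("beta must lie in (0, 1]")
        self.positions.append(t)
        self.betas.append(beta)
        while len(self.positions) > self.budget:
            # Evict the cached token whose decayed score is currently lowest
            victim = min(range(len(self.positions)), key=lambda k: self.log_score(k, t))
            del self.positions[victim]
            del self.betas[victim]


# Usage: budget of 3; the first token gets a near-1 gate value, like a sink token
cache = RetentionCache(budget=3)
for t, beta in enumerate([0.999, 0.5, 0.5, 0.9, 0.5, 0.5]):
    cache.append(t, beta)
print(cache.positions)  # [0, 3, 5]
```

In this example, the token with `beta = 0.999` survives every eviction and behaves much like a sink token. Tokens with `beta = 0.5` are dropped soon after newer tokens arrive. A newly added token always scores `log(1) = 0`, so it is never the one evicted on insertion.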
📝 Abstract
Memory and computation remain core bottlenecks in long-horizon LLM inference due to the quadratic cost of self-attention and the ever-growing key-value (KV) cache. Existing strategies for memory-bounded inference, such as quantization, offloading, or heuristic KV eviction, either incur high orchestration costs or rely on unreliable attention-based proxies of importance. We propose TRIM-KV, a novel approach that learns each token's intrinsic importance at creation time via a lightweight retention gate. Each gate predicts a scalar retention score that decays over time, reflecting the long-term utility of the token for a specific layer and head. Tokens with low scores are evicted when the memory budget is exceeded, so that the cache always holds the tokens with the highest predicted importance. TRIM-KV is trained efficiently through distillation from a frozen LLM combined with a capacity loss, requiring only gate fine-tuning and adding negligible inference overhead. Across mathematical reasoning (GSM8K, MATH-500, AIME24), procedural generation (LongProc), conversational long-memory benchmarks (LongMemEval), and long-context understanding (LongBench and SCBench), TRIM-KV consistently outperforms strong eviction and learnable retrieval baselines, especially in low-memory regimes. Remarkably, it even surpasses full-cache models in some settings, showing that selective retention can serve as a form of regularization, suppressing noise from uninformative tokens. Qualitative analyses further reveal that learned retention scores align with human intuition, naturally recovering heuristics such as sink tokens, sliding windows, and gist compression without explicit design. Beyond efficiency, retention scores provide insights into layer- and head-specific roles, suggesting a new path toward LLM interpretability.
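The training objective can be sketched in the same spirit. The abstract states only that the gates are trained by distillation from a frozen LLM together with a capacity loss. The specific forms below are assumptions:

- a KL divergence from the teacher's next-token distribution to the student's;
- a hinge penalty on the soft cache size `sum_i beta_i ** (t - i)` whenever it exceeds the budget.

The function `distill_capacity_loss` and its parameters are hypothetical. The sketch only evaluates the loss in plain Python. A working version would compute it inside an autodiff framework, so that gradients reach the gate parameters while the LLM itself stays frozen.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_teacher_student(teacher_logits: list[float], student_logits: list[float]) -> float:
    """KL(teacher || student) for one next-token distribution."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def capacity_penalty(betas: list[float], budget: float) -> float:
    """Mean hinge penalty on the soft cache size across decoding steps.

    Assumption: the soft cache size at step t is sum_{i <= t} beta_i ** (t - i).
    """
    total = 0.0
    for t in range(len(betas)):
        soft_size = sum(betas[i] ** (t - i) for i in range(t + 1))
        total += max(0.0, soft_size - budget)
    return total / len(betas)


def distill_capacity_loss(teacher_steps, student_steps, betas, budget, lam=1.0):
    """Hypothetical combined objective: mean per-step KL plus weighted capacity penalty.

    teacher_steps: logits from the frozen full-cache model, one list per step.
    student_steps: logits from the same model run with retention-gated attention.
    """
    distill = sum(
        kl_teacher_student(p, q) for p, q in zip(teacher_steps, student_steps)
    ) / len(teacher_steps)
    return distill + lam * capacity_penalty(betas, budget)


# Usage: identical logits give zero distillation loss, so only the penalty remains
steps = [[2.0, 0.5, -1.0]] * 4
print(distill_capacity_loss(steps, steps, betas=[1.0] * 4, budget=2, lam=0.5))  # 0.375
```

The two terms pull in opposite directions. The distillation term favors retaining tokens, since a fuller cache better matches the teacher. The capacity term pushes gate values down once the expected cache size exceeds the budget.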