🤖 AI Summary
Speculative decoding can mitigate the high autoregressive decoding latency of long-context reasoning with large language models, but it faces three key challenges: memory bloat from the draft model's KV cache, the distribution shift between short training sequences and long-context inference, and the inefficiency of tree attention over long prefixes. This paper proposes an efficient speculative decoding framework that addresses all three. Its core contributions are: (1) a novel draft model with a constant-size KV cache, drastically reducing memory overhead; (2) a scalable positional encoding scheme enabling seamless generalization from short training sequences to long-context inference; and (3) a hybrid prefix-tree attention mechanism that preserves tree-mask correctness while accelerating prefix computation. Evaluated on repository-level code completion, long-document summarization, and o1-style chain-of-thought reasoning, the method achieves an average end-to-end latency reduction of 38.6%, demonstrating the feasibility of high-throughput, low-latency speculative decoding for long-context scenarios.
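To make the first contribution concrete, the sketch below shows one simple way a draft model's KV cache can be kept at a constant size: a ring-buffer (sliding-window) cache that overwrites the oldest entries. The class name, shapes, and eviction policy are illustrative assumptions for this sketch, not the paper's exact draft architecture.

```python
# Minimal sketch of a constant-size KV cache (ring buffer); hypothetical design,
# not the paper's exact draft-model architecture.
import torch

class ConstantSizeKVCache:
    """Keeps at most `window` key/value pairs, overwriting the oldest ones."""
    def __init__(self, window: int, num_heads: int, head_dim: int):
        self.window = window
        self.k = torch.zeros(num_heads, window, head_dim)
        self.v = torch.zeros(num_heads, window, head_dim)
        self.length = 0  # total tokens seen so far

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor) -> None:
        # k_new, v_new: [num_heads, new_tokens, head_dim]
        for t in range(k_new.size(1)):
            slot = self.length % self.window  # ring-buffer position
            self.k[:, slot] = k_new[:, t]
            self.v[:, slot] = v_new[:, t]
            self.length += 1

    def get(self):
        """Return the cached keys/values in chronological order."""
        n = min(self.length, self.window)
        if self.length <= self.window:
            return self.k[:, :n], self.v[:, :n]
        start = self.length % self.window  # slot holding the oldest token
        idx = (torch.arange(n) + start) % self.window
        return self.k[:, idx], self.v[:, idx]

# Toy usage: appending 6 tokens to a window of 4 keeps memory bounded.
cache = ConstantSizeKVCache(window=4, num_heads=2, head_dim=8)
cache.append(torch.randn(2, 6, 8), torch.randn(2, 6, 8))
k, v = cache.get()
print(k.shape)  # torch.Size([2, 4, 8])
```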
📝 Abstract
Speculative decoding has become a promising technique to mitigate the high inference latency of autoregressive decoding in Large Language Models (LLMs). However, its effective application in LLMs still faces three key challenges: the increasing memory demands of the draft model, the distribution shift between short training corpora and long-context inference, and inefficiencies in attention implementation. In this work, we enhance the performance of speculative decoding in long-context settings by addressing these challenges. First, we propose a memory-efficient draft model with a constant-sized Key-Value (KV) cache. Second, we introduce novel position indices for short training data, enabling seamless adaptation from short-context training to long-context inference. Finally, we present an innovative attention aggregation method that combines fast implementations for prefix computation with standard attention for tree mask handling, effectively resolving the latency and memory inefficiencies of tree decoding. Our approach achieves strong results on various long-context tasks, including repository-level code completion, long-context summarization, and o1-like long reasoning tasks, demonstrating significant improvements in latency reduction. The code is available at https://github.com/sail-sg/LongSpec.
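To illustrate the attention-aggregation idea, below is a minimal PyTorch sketch (hypothetical function names and toy shapes) of how a mask-free pass over the long cached prefix can be merged with a tree-masked pass over the speculative tokens by tracking the log-sum-exp normalizers of each partial softmax. In practice a fast fused kernel would handle the prefix pass; this naive version only shows the aggregation math, not the paper's implementation.

```python
# Minimal sketch of prefix/tree attention aggregation; hypothetical names and
# shapes, not the authors' implementation.
import torch

def attend_with_lse(q, k, v, mask=None):
    """Scaled dot-product attention that also returns the log-sum-exp of the
    attention logits, so partial results can be merged exactly later."""
    scale = q.size(-1) ** -0.5
    logits = (q @ k.transpose(-2, -1)) * scale            # [..., q_len, k_len]
    if mask is not None:
        logits = logits.masked_fill(~mask, float("-inf"))
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)   # [..., q_len, 1]
    out = torch.exp(logits - lse) @ v                     # [..., q_len, d]
    return out, lse

def hybrid_tree_attention(q, k_prefix, v_prefix, k_tree, v_tree, tree_mask):
    """Attend to the cached prefix (mask-free, so a fast kernel could be used)
    and to the small speculative tree (with its tree mask), then merge the two
    partial softmaxes via their log-sum-exp normalizers."""
    out_p, lse_p = attend_with_lse(q, k_prefix, v_prefix)             # prefix path
    out_t, lse_t = attend_with_lse(q, k_tree, v_tree, mask=tree_mask)  # tree path
    lse = torch.logsumexp(torch.cat([lse_p, lse_t], dim=-1), dim=-1, keepdim=True)
    return out_p * torch.exp(lse_p - lse) + out_t * torch.exp(lse_t - lse)

# Toy usage: 4 draft-tree queries over a 128-token prefix and an 8-node tree.
q = torch.randn(1, 4, 64)
k_p, v_p = torch.randn(1, 128, 64), torch.randn(1, 128, 64)
k_t, v_t = torch.randn(1, 8, 64), torch.randn(1, 8, 64)
tree_mask = torch.ones(4, 8).tril().bool()  # placeholder tree mask
out = hybrid_tree_attention(q, k_p, v_p, k_t, v_t, tree_mask)
print(out.shape)  # torch.Size([1, 4, 64])
```

Merging via log-sum-exp is exact: weighting each partial output by the ratio of its normalizer to the combined normalizer reproduces the softmax over the full prefix-plus-tree key set, so the tree mask is respected without running a masked kernel over the entire long context.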