🤖 AI Summary
To address the inference latency bottleneck imposed by autoregressive decoding in large language models (LLMs), this paper proposes PARallel Draft (PARD), a novel speculative decoding framework that replaces the conventional autoregressive draft model with a parallel one capable of generating multiple tokens in a single forward pass. Key contributions include: (1) a conditional token-dropping training strategy that improves draft model training efficiency by 3×; (2) a target-independent design, enabling a single draft model to generalize across an entire LLM family (e.g., the LLaMA3.1 series); and (3) a draft-verify decoding mechanism co-designed with an optimized inference framework. Evaluated on LLaMA3.1-8B, PARD achieves 311.5 tokens/s, 4.08× faster than standard autoregressive inference, and significantly outperforms existing speculative decoding methods.
📝 Abstract
The autoregressive nature of large language models (LLMs) limits inference speed. Each forward pass generates only a single token and is often bottlenecked by memory bandwidth. Speculative decoding alleviates this issue using a draft-then-verify approach to accelerate token generation. However, the overhead introduced during the draft phase and the training cost of the draft model limit the efficiency and adaptability of speculative decoding. In this work, we introduce PARallel Draft (PARD), a novel speculative decoding method that enables low-cost adaptation of autoregressive draft models into parallel draft models. PARD enhances inference efficiency by predicting multiple future tokens in a single forward pass during the draft phase, and incorporates a conditional drop token method to accelerate training. Its target-independence property allows a single draft model to serve an entire family of models, minimizing adaptation cost. The proposed conditional drop token method improves draft model training efficiency by 3×. On our optimized inference framework, PARD accelerates LLaMA3.1-8B inference by 4.08×, achieving 311.5 tokens per second.
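The draft-then-verify loop at the heart of speculative decoding, with a parallel draft step proposing several tokens in one pass, can be sketched with toy stand-in models. Everything here (function names, the arithmetic token rules) is illustrative scaffolding, not the paper's actual networks or verification policy:

```python
# Toy sketch of parallel-draft speculative decoding (greedy verification).
# The draft proposes k future tokens in ONE call; the target then checks
# them position by position, keeping the longest verified prefix plus one
# token contributed by the target itself.

def parallel_draft(context, k):
    """Hypothetical parallel draft model: propose k next tokens in one pass.
    Here, a cheap rule guessing each next token as (last + i + 1) mod 10."""
    last = context[-1]
    return [(last + i + 1) % 10 for i in range(k)]

def target_next(context):
    """Hypothetical target model: the 'ground truth' next token.
    Here, (last + 1) mod 10, except it emits 0 after a 7."""
    last = context[-1]
    return 0 if last == 7 else (last + 1) % 10

def speculative_step(context, k=4):
    """One draft-then-verify step; returns the tokens accepted this step."""
    proposals = parallel_draft(context, k)
    accepted = []
    for tok in proposals:
        expected = target_next(context + accepted)
        if tok == expected:
            accepted.append(tok)        # draft token verified, keep it
        else:
            accepted.append(expected)   # mismatch: take target's token, stop
            break
    else:
        # All k draft tokens accepted; target contributes one bonus token.
        accepted.append(target_next(context + accepted))
    return accepted

print(speculative_step([5], k=4))  # → [6, 7, 0]: two drafts verified,
                                   #   third corrected by the target
```

Even the rejected step yields one token, so each target forward pass emits at least one and up to k+1 tokens; this amortization over the (cheap) parallel draft call is what drives the speedup.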