🤖 AI Summary
To address the inefficiency, high energy consumption, and limited generation quality of autoregressive single-token decoding in large language models (LLMs), this paper proposes multi-token assisted decoding (MTAD), a framework that accelerates multi-token joint decoding (MTJD) by pairing the target model with a smaller auxiliary model. Its core contributions are threefold: (1) MTJD, which samples multiple tokens per iteration from their joint distribution, theoretically reducing perplexity and improving task performance; (2) MTAD, which uses the auxiliary model to approximate the large model's joint distribution, with a verification mechanism that guarantees the approximation error is bounded; and (3) improved decoding efficiency over conventional speculative decoding. Evaluated on Llama-2 and OPT models (13B–70B), MTAD reduces perplexity by 21.2% and improves downstream task performance relative to standard single-token sampling, while achieving a 1.42× speedup in inference and 1.54× lower energy consumption than conventional speculative decoding, demonstrating simultaneous gains in both efficiency and generation quality.
📝 Abstract
Large language models (LLMs) have achieved remarkable success across diverse tasks, yet their inference processes are hindered by substantial time and energy demands due to single-token generation at each decoding step. While previous methods such as speculative decoding mitigate these inefficiencies by producing multiple tokens per step, each token is still generated from its single-token distribution, thereby enhancing speed without improving effectiveness. In contrast, our work simultaneously enhances inference speed and improves output effectiveness. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and enhancing task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we introduce multi-token assisted decoding (MTAD), a novel framework designed to accelerate MTJD. MTAD leverages a smaller auxiliary model to approximate the joint distribution of a larger model, incorporating a verification mechanism that not only ensures the accuracy of this approximation, but also improves decoding efficiency over conventional speculative decoding. Theoretically, we demonstrate that MTAD closely approximates exact MTJD with bounded error. Empirical evaluations using Llama-2 and OPT models ranging from 13B to 70B parameters across various tasks reveal that MTAD reduces perplexity by 21.2% and improves downstream performance compared to standard single-token sampling. Furthermore, MTAD achieves a 1.42x speed-up and consumes 1.54x less energy than conventional speculative decoding methods. These results highlight MTAD's ability to make multi-token joint decoding both effective and efficient, promoting more sustainable and high-performance deployment of LLMs.
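To make the draft-and-verify idea behind speculative-style decoding concrete, here is a minimal, self-contained Python sketch. It is not the paper's MTAD algorithm: the `draft_dist` and `target_dist` functions are toy stand-ins for the auxiliary and large models over a three-token vocabulary, and the acceptance rule shown is the standard speculative-decoding rejection sampling (accept with probability min(1, p_target/p_draft), else resample from the normalized residual), not the joint-distribution verification MTAD proposes.

```python
import random

random.seed(0)

VOCAB = ["a", "b", "c"]

def draft_dist(prefix):
    # Toy stand-in for the small auxiliary model's next-token distribution.
    return {"a": 0.6, "b": 0.3, "c": 0.1}

def target_dist(prefix):
    # Toy stand-in for the large target model's next-token distribution.
    return {"a": 0.5, "b": 0.4, "c": 0.1}

def sample(dist):
    # Draw one token from a {token: probability} dict.
    r = random.random()
    acc = 0.0
    for tok, p in dist.items():
        acc += p
        if r < acc:
            return tok
    return tok  # guard against floating-point round-off

def speculative_step(prefix, gamma=4):
    """One draft-and-verify step: the draft model proposes `gamma` tokens,
    then the target model accepts each with prob min(1, p_target/p_draft)."""
    drafted = []
    for _ in range(gamma):
        drafted.append(sample(draft_dist(prefix + drafted)))
    accepted = []
    for tok in drafted:
        q = draft_dist(prefix + accepted)
        p = target_dist(prefix + accepted)
        if random.random() < min(1.0, p[tok] / q[tok]):
            accepted.append(tok)
        else:
            # On rejection, resample from the residual max(0, p - q),
            # which keeps the accepted sequence distributed as the target.
            residual = {t: max(0.0, p[t] - q[t]) for t in VOCAB}
            z = sum(residual.values())
            accepted.append(sample({t: v / z for t, v in residual.items()}))
            break
    return accepted

out = speculative_step([])
print(out)
```

Each call emits between 1 and `gamma` tokens per target-model verification pass, which is where the latency and energy savings come from; MTAD's contribution, per the abstract, is to verify against an approximation of the *joint* multi-token distribution rather than per-token marginals, improving output quality as well as speed.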