🤖 AI Summary
This work addresses the sharp drop in acceptance length that autoregressive speculative-decoding drafters suffer under prompt-template perturbations and long-context inputs, a degradation whose root cause had remained unclear. The study identifies and characterizes a previously unreported issue termed “attention drift”: as the number of speculation steps grows, the drafter's attention progressively shifts away from the original prompt and toward the tokens it has generated itself. The authors trace drift to growth of the drafter's hidden-state magnitude along the un-normalized residual path between chain steps, and propose two normalization strategies that suppress it: post-norm on the drafter's hidden states and per-hidden-state RMSNorm on the captured target hidden states. Relative to pre-norm EAGLE3, the changes improve acceptance length by up to 2× under template perturbation, by 1.18× on long-context tasks, and by 1.10× across seven standard benchmarks. They also let drafters trained with shorter train-time-test depths generalize to longer drafting sequences at inference.
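To make “attention drift” concrete, one simple way to quantify it is the share of a drafter query's attention mass that lands on prompt positions at each speculation step; a falling share means attention is moving onto drafted tokens. This metric is an assumption of the sketch below, not necessarily the measurement used in the paper. `prompt_attention_share` and `drift_curve` are hypothetical helpers, and the attention rows are made-up toy numbers rather than measurements.

```python
from typing import List, Sequence


def prompt_attention_share(attn_row: Sequence[float], prompt_len: int) -> float:
    """Fraction of one query's attention mass that falls on prompt positions.

    attn_row holds the attention weights of a single drafter query over all key
    positions, prompt tokens first and drafted tokens after them.
    """
    total = sum(attn_row)
    if total <= 0:
        raise ValueError("attention row must have positive mass")
    return sum(attn_row[:prompt_len]) / total


def drift_curve(attn_rows: List[Sequence[float]], prompt_len: int) -> List[float]:
    """Prompt-attention share at each speculation step (hypothetical drift metric)."""
    return [prompt_attention_share(row, prompt_len) for row in attn_rows]


# Toy attention rows (made up): 4 prompt keys, plus one more drafted key per step.
rows = [
    [0.25, 0.25, 0.25, 0.20, 0.05],
    [0.20, 0.20, 0.20, 0.20, 0.10, 0.10],
    [0.10, 0.10, 0.10, 0.10, 0.20, 0.20, 0.20],
]
curve = drift_curve(rows, prompt_len=4)  # decreasing: attention moves off the prompt
```

In a real drafter the rows would come from the attention probabilities of the last query at each chain step (averaged over heads, for example); how to aggregate across heads and layers is a further assumption left open here.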
📝 Abstract
Speculative decoding accelerates LLM inference by drafting future tokens with a small model, but drafter models degrade sharply under template perturbation and on long-context inputs. We identify a previously unreported phenomenon we call \textbf{attention drift}: as the drafter generates successive tokens within a speculation chain, its attention progressively moves from the prompt onto its own recently generated tokens. We observe this in both \emph{EAGLE3} drafters and \emph{MTP heads}, suggesting that drift is a property of drafter designs in general rather than of one particular architecture. We trace it to the un-normalized residual path between chain steps: the magnitude of the drafter's hidden state grows monotonically with chain depth, so the drafter's dynamics are consistent with additional pre-norm transformer layers stacked on the target model rather than with a standalone autoregressive predictor. To limit this growth, we propose two architectural changes: post-norm on the drafter hidden states, and per-hidden-state RMSNorm applied after capturing the target hidden states. Our interventions improve acceptance length over the current leading model, pre-norm EAGLE3, by up to $2\times$ under template perturbation, $1.18\times$ on long-context tasks, and $1.10\times$ on seven standard benchmarks spanning multi-turn chat, math, and coding. Our changes also allow drafters trained with shorter train-time-test depths to generalize to longer drafting sequences.
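The mechanism described above can be illustrated numerically with a toy model. In a pre-norm update, h ← h + f(norm(h)), only the sublayer input is normalized, so nothing rescales h itself between chain steps; in a post-norm update, h ← norm(h + f(h)), the hidden-state RMS is reset after every step. The sketch below is a minimal illustration under stated assumptions, not the authors' implementation: `toy_sublayer` is a hypothetical fixed neighbour-mixing map standing in for the drafter block, RMSNorm has no learned gain, and `fuse_captured` assumes that "per-hidden-state RMSNorm" means normalizing each captured target state separately before concatenating them.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms(x: Vector) -> float:
    return math.sqrt(sum(v * v for v in x) / len(x))


def rmsnorm(x: Vector, eps: float = 1e-6) -> Vector:
    # Unit-gain RMSNorm (no learned scale), for illustration only.
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def toy_sublayer(u: Vector) -> Vector:
    # Hypothetical stand-in for the drafter block: mixes each coordinate with its neighbour.
    d = len(u)
    return [0.5 * u[i] + 0.5 * u[(i + 1) % d] for i in range(d)]


def prenorm_chain(h0: Vector, steps: int, f: Callable[[Vector], Vector]) -> List[float]:
    """Pre-norm residual across chain steps: h <- h + f(norm(h)); h itself is never rescaled."""
    h, history = list(h0), []
    for _ in range(steps):
        update = f(rmsnorm(h))
        h = [a + b for a, b in zip(h, update)]
        history.append(rms(h))
    return history


def postnorm_chain(h0: Vector, steps: int, f: Callable[[Vector], Vector]) -> List[float]:
    """Post-norm: h <- norm(h + f(h)); hidden-state RMS is reset to ~1 after every step."""
    h, history = list(h0), []
    for _ in range(steps):
        update = f(h)
        h = rmsnorm([a + b for a, b in zip(h, update)])
        history.append(rms(h))
    return history


def fuse_captured(states: List[Vector]) -> Vector:
    """Assumed per-hidden-state RMSNorm: normalize each captured target state, then concatenate."""
    fused: Vector = []
    for s in states:
        fused.extend(rmsnorm(s))
    return fused


h0 = [1.0, 2.0, 3.0, 4.0]
pre_rms = prenorm_chain(h0, steps=6, f=toy_sublayer)    # grows every step
post_rms = postnorm_chain(h0, steps=6, f=toy_sublayer)  # stays at ~1.0
fused = fuse_captured([[0.5, 1.5], [50.0, 150.0]])      # both halves rescaled to RMS ~1
```

With a positive starting state, every toy pre-norm update is positive in each coordinate, so the RMS rises strictly at each step, mirroring the monotonic growth the abstract reports; the post-norm chain stays at unit RMS regardless of depth. Normalizing each captured state separately also keeps a large-magnitude layer from dominating the concatenated drafter input.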