🤖 AI Summary
This work addresses the significant computational overhead of the prefill phase in large language model inference at short-to-moderate context lengths, where existing feed-forward network (FFN) sparsification methods struggle to balance acceleration and accuracy. The authors propose FastForward, a framework that introduces, for the first time, a predictive FFN sparsification mechanism tailored specifically to the prefill phase. It dynamically preserves critical computations and compensates for information loss via block-level, context-aware neuron selection, combining a lightweight expert predictor, an error compensation network, and a layer-wise sparsity scheduling algorithm. Evaluated on 8B-scale models including LLaMA and Qwen, FastForward achieves up to 1.45× speedup under compute-bound conditions at 50% FFN sparsity, substantially reduces first-token latency, and incurs less than 6% accuracy degradation.
📝 Abstract
The prefill stage of large language model (LLM) inference is a key computational bottleneck for long-context workloads. At short-to-moderate context lengths (1K–16K tokens), Feed-Forward Networks (FFNs) dominate this cost, accounting for most of the total FLOPs. Existing FFN sparsification methods, designed for autoregressive decoding, fail to exploit the prefill stage's parallelism and often degrade accuracy. To address this, we introduce FastForward, a predictive sparsity framework that accelerates LLM prefill through block-wise, context-aware FFN sparsity. FastForward combines (1) a lightweight expert predictor to select high-importance neurons per block, (2) an error compensation network to correct sparsity-induced errors, and (3) a layer-wise sparsity scheduler to allocate compute based on token-mixing importance. Across LLaMA and Qwen models up to 8B parameters, FastForward delivers up to 1.45× compute-bound speedup at 50% FFN sparsity with <6% accuracy loss compared to the dense baseline on LongBench, substantially reducing Time-to-First-Token (TTFT) for efficient, long-context LLM inference on constrained hardware.
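To make the three-component mechanism concrete, here is a minimal NumPy sketch of block-wise predictive FFN sparsity. All names, dimensions, and the linear predictor/compensation forms are illustrative assumptions, not the paper's actual architecture: the real expert predictor, error compensation network, and scheduler are learned components described in the full paper.

```python
import numpy as np

# Hypothetical dimensions and weights -- illustrative only, not from the paper.
d_model, d_ff, block_len = 64, 256, 32
rng = np.random.default_rng(0)

W_in = rng.standard_normal((d_model, d_ff)) * 0.05    # FFN up-projection
W_out = rng.standard_normal((d_ff, d_model)) * 0.05   # FFN down-projection
W_pred = rng.standard_normal((d_model, d_ff)) * 0.05  # lightweight neuron-importance predictor
W_comp = rng.standard_normal((d_model, d_model)) * 0.01  # toy stand-in for the compensation network

def sparse_ffn_block(x, sparsity=0.5):
    """Score FFN neurons once per token block from a context summary,
    keep the top (1 - sparsity) fraction, compute only those columns,
    and add a cheap compensation term for the dropped neurons."""
    ctx = x.mean(axis=0)                    # block-level context summary
    scores = ctx @ W_pred                   # predicted per-neuron importance
    k = int(d_ff * (1.0 - sparsity))
    keep = np.argsort(scores)[-k:]          # indices of retained neurons
    h = np.maximum(x @ W_in[:, keep], 0.0)  # activation over kept neurons only
    y = h @ W_out[keep, :]                  # partial down-projection
    return y + x @ W_comp                   # error-compensation correction

x = rng.standard_normal((block_len, d_model))
y = sparse_ffn_block(x, sparsity=0.5)
print(y.shape)  # (32, 64)
```

At 50% sparsity the two FFN matmuls touch half of `d_ff`, which is where the compute-bound speedup comes from; in this sketch a layer-wise scheduler would simply pass a different `sparsity` value to each layer.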