FiRST: Finetuning Router-Selective Transformers for Input-Adaptive Latency Reduction

📅 2024-10-16
🏛️ arXiv.org
📈 Citations: 1
Influential: 0
📄 PDF
🤖 AI Summary
To address the high inference latency of autoregressive large language models (LLMs) on edge devices—caused by sequential, layer-by-layer decoding—this paper proposes an input-adaptive dynamic layer-skipping mechanism. The method introduces a sequence-level, prompt-driven differentiable router that selects Transformer layers in a task- and input-aware manner while preserving KV cache integrity. Leveraging lightweight adapters and a plug-and-play architecture, it is compatible with arbitrary LLMs without structural modification. Evaluated across multiple benchmarks, the approach reduces inference latency by 35–52%, maintains or improves generation quality over the baseline model, and significantly outperforms early-exit and fixed-skipping baselines. To the best of the authors' knowledge, this is the first low-latency autoregressive inference framework that jointly ensures KV cache compatibility, input adaptivity, and practical deployability on resource-constrained edge devices.

📝 Abstract
Auto-regressive Large Language Models (LLMs) demonstrate remarkable performance across different domains such as vision and language processing. However, due to sequential processing through a stack of transformer layers, autoregressive decoding faces significant computation/latency challenges, particularly in resource-constrained environments like mobile and edge devices. Existing approaches in the literature that aim to improve latency via skipping layers have two distinct flavors - 1) Early exit, and 2) Input-agnostic heuristics where tokens exit at pre-determined layers irrespective of the input sequence. Both of the above strategies have limitations - the former cannot handle the KV caching necessary for speed-ups in modern frameworks, and the latter does not capture the variation in layer importance across tasks or, more generally, across input sequences. To address both limitations, we propose FiRST, an algorithm that reduces inference latency by using layer-specific routers to select a subset of transformer layers adaptively for each input sequence - the prompt (during the prefill stage) decides which layers will be skipped during decoding. FiRST preserves compatibility with KV caching, enabling faster inference while being quality-aware. FiRST is model-agnostic and can be easily enabled on any pre-trained LLM. Our approach reveals that input adaptivity is critical - indeed, different task-specific middle layers play a crucial role in evolving hidden representations depending on the task. Extensive experiments show that FiRST significantly reduces latency while outperforming other layer selection strategies in quality metrics. It retains performance competitive with the base model (without layer skipping) and in some cases even improves upon it. FiRST is thus a promising and efficient solution for LLM deployment in low-resource environments.
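The sequence-level routing idea in the abstract—score each layer once from the prompt during prefill, then keep that layer subset fixed for the whole decode—can be sketched as a toy. This is a hedged illustration, not the paper's implementation; the pooling choice, sigmoid router, and all names (`select_layers`, `router_weights`, `threshold`) are assumptions:

```python
import math

def select_layers(prompt_hidden, router_weights, threshold=0.5):
    """Hypothetical sketch of sequence-level layer selection.

    prompt_hidden:  list of hidden-state vectors from the prefill pass.
    router_weights: one lightweight weight vector per transformer layer.
    Returns a keep-mask that stays fixed for the whole generation, so the
    KV cache of every kept layer remains valid: skipped layers are simply
    never run and never consulted.
    """
    d = len(prompt_hidden[0])
    n = len(prompt_hidden)
    # Mean-pool the prompt so each router emits one score per sequence,
    # not per token (the "sequence-level, prompt-driven" property).
    pooled = [sum(v[i] for v in prompt_hidden) / n for i in range(d)]
    keep = []
    for w in router_weights:
        logit = sum(wi * pi for wi, pi in zip(w, pooled))
        score = 1.0 / (1.0 + math.exp(-logit))  # sigmoid router score
        keep.append(score > threshold)
    return keep
```

Because the mask is computed once at prefill time, decoding never changes which layers run mid-sequence—this is what keeps the per-layer KV caches consistent, unlike token-level early exit.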
Problem

Research questions and friction points this paper is trying to address.

Reduces LLM inference latency via input-adaptive layer skipping
Maintains KV caching compatibility for faster autoregressive decoding
Improves efficiency without sacrificing model performance quality
Innovation

Methods, ideas, or system contributions that make the work stand out.

Uses layer-specific routers for adaptive layer selection
Preserves KV caching compatibility for faster inference
Model-agnostic and applicable to any pre-trained LLM
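The KV-caching point in the bullets above can be made concrete with a small toy: because the keep-mask is fixed at prefill time, each decoding step runs exactly the same layer subset, so every kept layer's cache grows uniformly and skipped layers have no cache at all. A minimal sketch (all names hypothetical; real layers would be attention blocks, not `x + delta`):

```python
def make_layer(delta):
    # Toy stand-in for a transformer layer: appends to its own KV cache
    # and transforms the hidden state.
    def layer(x, cache):
        cache.append(x)          # kept layers grow their cache every step
        return x + delta
    return layer

def decode_step(state, layers, keep_mask, kv_cache):
    # One decoding step: only the layers selected at prefill time run,
    # so per-layer cache lengths never diverge mid-generation.
    for i, layer in enumerate(layers):
        if keep_mask[i]:
            state = layer(state, kv_cache[i])
    return state

layers = [make_layer(1), make_layer(10), make_layer(100)]
kv_cache = {0: [], 1: [], 2: []}
out = decode_step(0, layers, [True, False, True], kv_cache)
```

After this step the skipped layer's cache (`kv_cache[1]`) is still empty and will remain so for the entire generation, which is why a fixed prompt-time selection composes cleanly with standard KV caching while token-level early exit does not.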