🤖 AI Summary
To improve the inference efficiency of large language models (LLMs), this paper proposes a hybrid attention distillation method that avoids pretraining from scratch: it adaptively replaces selected layers of a pretrained softmax Transformer with linear attention. The key contribution is a lightweight, KL-divergence-based layer importance estimation mechanism that requires only a small amount of training on generic text, which makes the importance scoring data-efficient and task-agnostic. Once the layers are selected, distillation follows the existing RADLADS pipeline, which combines attention weight transfer, hidden-state alignment, and output distribution matching, followed by lightweight fine-tuning. Experiments across multiple benchmarks demonstrate substantial improvements in inference speed and memory efficiency while preserving model performance. The adaptive layer selection strategy outperforms both fixed-ratio uniform interleaving and existing heuristic methods that rely on specialized diagnostic datasets.
📝 Abstract
Distilling pretrained softmax attention Transformers into more efficient hybrid architectures that interleave softmax and linear attention layers is a promising approach for improving the inference efficiency of LLMs without requiring expensive pretraining from scratch. A critical factor in the conversion process is layer selection, i.e., deciding which layers to convert to linear attention variants. This paper describes a simple and efficient recipe for layer selection that uses layer importance scores derived from a small amount of training on generic text data. Once the layers have been selected, we use a recent pipeline for the distillation process itself (RADLADS; Goldstein et al., 2025), which consists of attention weight transfer, hidden state alignment, and KL-based distribution matching, followed by a small amount of finetuning. We find that this approach is more effective than existing approaches for layer selection, including heuristics that uniformly interleave linear attention layers at a fixed ratio, as well as more involved approaches that rely on specialized diagnostic datasets.
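To make the layer selection step concrete, here is a minimal sketch of how per-layer importance scores could be turned into a hybrid layout, next to the fixed-ratio uniform baseline. It is not the paper's implementation. The abstract does not say how the scores are computed, so the values below are made up. The function names, the convention that a higher score means "keep this layer as softmax attention", and the budget parameter `num_softmax` are all assumptions for illustration.

```python
from typing import List, Sequence


def select_softmax_layers(importance: Sequence[float], num_softmax: int) -> List[int]:
    """Return sorted indices of the layers to keep as softmax attention.

    Assumption: a higher score means the layer is more important to keep.
    """
    if not 0 <= num_softmax <= len(importance):
        raise ValueError("num_softmax must be between 0 and the number of layers")
    # Rank layers from most to least important; ties break toward the lower index.
    ranked = sorted(range(len(importance)), key=lambda i: (-importance[i], i))
    return sorted(ranked[:num_softmax])


def uniform_softmax_layers(num_layers: int, every: int) -> List[int]:
    """Fixed-ratio baseline: keep every `every`-th layer as softmax attention."""
    return list(range(every - 1, num_layers, every))


def layer_plan(num_layers: int, softmax_layers: Sequence[int]) -> List[str]:
    """Describe the hybrid model as one attention type per layer."""
    keep = set(softmax_layers)
    return ["softmax" if i in keep else "linear" for i in range(num_layers)]


# Hypothetical importance scores for an 8-layer model (illustrative values only).
scores = [0.9, 0.1, 0.2, 0.8, 0.05, 0.3, 0.7, 0.15]

adaptive = select_softmax_layers(scores, num_softmax=2)
uniform = uniform_softmax_layers(len(scores), every=4)

print("adaptive:", layer_plan(len(scores), adaptive))
print("uniform: ", layer_plan(len(scores), uniform))
```

With the same budget of two softmax layers, the score-based selection and the uniform baseline pick different layers. That difference in layout is what the comparison in the abstract is about.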