Retrieval-Aware Distillation for Transformer-SSM Hybrids

📅 2026-02-11
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
This work addresses why state space models (SSMs) underperform Transformers on in-context retrieval tasks, a gap attributed mainly to their inability to replicate a small set of retrieval-critical attention heads. The authors propose a retrieval-aware distillation approach that converts a pretrained Transformer into a hybrid architecture. Only about 2% of attention heads, those identified as essential for retrieval, are retained, and the remaining components are distilled into recurrent (SSM) heads. Combining attention-head ablation on a synthetic retrieval task, sparse non-uniform attention placement, and SSM state compression, the method recovers over 95% of the teacher model's performance on retrieval-heavy tasks while being 5–6× more memory-efficient than comparable hybrids. This substantially narrows the Transformer–SSM gap on retrieval while using far fewer attention heads than prior hybrids, which retain at least 25%.
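The head-selection step can be sketched as a ranking by single-head ablation. This is a minimal illustration, not the authors' code: it assumes a hypothetical `eval_fn` that returns synthetic-retrieval accuracy with a given set of `(layer, head)` pairs masked, and it assumes the score is the accuracy drop when each head is ablated alone. The function name, the `keep_fraction` parameter, and the toy benchmark are all hypothetical.

```python
from typing import Callable, Dict, FrozenSet, List, Tuple

Head = Tuple[int, int]  # (layer index, head index)


def select_retrieval_heads(
    num_layers: int,
    num_heads: int,
    eval_fn: Callable[[FrozenSet[Head]], float],
    keep_fraction: float = 0.02,
) -> List[Head]:
    """Hypothetical sketch: rank heads by the retrieval-accuracy drop when each
    is ablated alone, and keep the top `keep_fraction` of all heads."""
    baseline = eval_fn(frozenset())  # accuracy with no heads masked
    drops: Dict[Head, float] = {}
    for layer in range(num_layers):
        for head in range(num_heads):
            drops[(layer, head)] = baseline - eval_fn(frozenset({(layer, head)}))
    k = max(1, round(keep_fraction * num_layers * num_heads))
    # sorted() is stable even with reverse=True, so ties keep (layer, head) order.
    ranked = sorted(drops, key=drops.__getitem__, reverse=True)
    return sorted(ranked[:k])  # kept attention heads, listed in (layer, head) order


# Toy stand-in for a synthetic retrieval benchmark (hypothetical numbers):
# only two heads matter, and masking them lowers accuracy by a fixed amount.
critical = {(3, 5): 0.5, (10, 2): 0.3}


def toy_eval(masked: FrozenSet[Head]) -> float:
    return 1.0 - sum(critical.get(h, 0.0) for h in masked)


# 16 layers x 32 heads = 512 heads; 2% of 512 rounds to 10 kept heads.
kept = select_retrieval_heads(16, 32, toy_eval)
```

Every head not in `kept` would then be replaced by a recurrent head and trained by distillation (not shown). Ranking by single-head ablation costs one evaluation per head and ignores interactions between heads. It is a simple, hedged stand-in for whatever ablation protocol the paper actually uses. Because the ranking is global rather than per layer, the kept heads can cluster in a few layers, which matches the sparse, non-uniform placement described above.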

📝 Abstract
State-space models (SSMs) offer efficient sequence modeling but lag behind Transformers on benchmarks that require in-context retrieval. Prior work links this gap to a small set of attention heads, termed Gather-and-Aggregate (G&A), which SSMs struggle to reproduce. We propose *retrieval-aware distillation*, which converts a pretrained Transformer into a hybrid student by preserving only these retrieval-critical heads and distilling the rest into recurrent heads. We identify the essential heads via ablation on a synthetic retrieval task, producing a hybrid with sparse, non-uniform attention placement. We show that preserving **just 2% of attention heads recovers over 95% of teacher performance on retrieval-heavy tasks** (10 heads in a 1B model), requiring far fewer heads than hybrids that retain at least 25%. We further find that large recurrent states often compensate for missing retrieval: once retrieval is handled by these heads, the SSM backbone can be simplified with limited loss, even with an $8\times$ reduction in state dimension. By reducing both the attention cache and the SSM state, the resulting hybrid is $5$--$6\times$ more memory-efficient than comparable hybrids, closing the Transformer--SSM gap at a fraction of the memory cost.
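To see why shrinking both the attention cache and the SSM state matters, inference memory can be estimated with simple arithmetic. This sketch rests on several assumptions. Each retained attention head stores its own key and value vector per token, so grouped-query sharing is ignored. Each recurrent head keeps a `head_dim × ssm_state_dim` state. Elements are 2 bytes. The two configurations are hypothetical and are not the paper's models, so the resulting ratio is illustrative only and is not meant to reproduce the reported 5–6× figure.

```python
from dataclasses import dataclass


@dataclass
class HybridConfig:
    attn_heads: int          # retained softmax-attention heads (cache grows with length)
    ssm_heads: int           # recurrent heads (fixed-size state)
    head_dim: int
    ssm_state_dim: int       # per-channel state size of each recurrent head
    bytes_per_elem: int = 2  # assumed 16-bit storage


def inference_memory_bytes(cfg: HybridConfig, seq_len: int) -> int:
    # KV cache: one key and one value vector per token per attention head.
    kv_cache = 2 * cfg.attn_heads * cfg.head_dim * seq_len
    # SSM state: a head_dim x ssm_state_dim matrix per recurrent head,
    # independent of sequence length.
    ssm_state = cfg.ssm_heads * cfg.head_dim * cfg.ssm_state_dim
    return (kv_cache + ssm_state) * cfg.bytes_per_elem


# Hypothetical 512-head models: a hybrid keeping 25% attention heads versus one
# keeping 10 heads (about 2%) with an 8x smaller SSM state.
wide_hybrid = HybridConfig(attn_heads=128, ssm_heads=384, head_dim=64, ssm_state_dim=64)
sparse_hybrid = HybridConfig(attn_heads=10, ssm_heads=502, head_dim=64, ssm_state_dim=8)

wide_bytes = inference_memory_bytes(wide_hybrid, seq_len=4096)
sparse_bytes = inference_memory_bytes(sparse_hybrid, seq_len=4096)
```

At long contexts the KV-cache term dominates, so cutting the number of attention heads gives most of the savings. The state-dimension reduction matters more at short contexts, where the fixed SSM state makes up a larger share of the total.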

Problem

Research questions and friction points this paper is trying to address.

state-space models
in-context retrieval
attention heads
Transformer-SSM gap
retrieval performance

Innovation

Methods, ideas, or system contributions that make the work stand out.

retrieval-aware distillation
state-space models
attention sparsification
hybrid architecture
memory efficiency