Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models

📅 2024-08-19
🏛️ Neural Information Processing Systems
📈 Citations: 3
Influential: 0
🤖 AI Summary
Transformer models incur substantial inference overhead due to the quadratic complexity of self-attention, while emerging sub-quadratic architectures—such as State Space Models (SSMs)—offer computational efficiency but lag significantly behind state-of-the-art Transformers in pretraining scale and resource investment. To bridge this gap, the paper proposes MOHAWK: a fine-grained knowledge distillation framework for transferring knowledge from Transformers to SSMs. MOHAWK enables high-fidelity transfer via progressive alignment of token-mixing matrices, hidden states, and end-to-end predictions, and is demonstrated with Mamba-2 distilled from Phi-family Transformers. With only 3B tokens, the resulting Phi-Mamba surpasses all past open-source non-Transformer baselines; a Hybrid Phi-Mamba trained on 5B tokens achieves further gains. This provides empirical validation that SSMs can effectively reuse Transformer pretraining investments—reaching strong performance with less than 1% of the data typically used to train such models from scratch—establishing a scalable distillation paradigm for sub-quadratic models.

📝 Abstract
Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially less computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea to our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens and a hybrid version (Hybrid Phi-Mamba) using 5B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.
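The three-stage matching procedure described in the abstract can be sketched numerically. The toy tensors, shapes, and loss forms below are illustrative assumptions (not the authors' implementation): stage 1 matches the (L x L) token-mixing matrices (the attention map vs. the materialized SSM sequence-mixing matrix), stage 2 matches each block's hidden states, and stage 3 distills the end-to-end output distributions.

```python
import numpy as np

# Hedged sketch of MOHAWK's progressive alignment stages on random toy
# tensors. Real training would use activations from a pretrained
# Transformer teacher and an SSM student; everything here is synthetic.

rng = np.random.default_rng(0)
L, D, V = 8, 16, 32  # sequence length, hidden dim, vocab size (assumed)

# Stage 1: matrix orientation — align the (L x L) token-mixing matrices.
# For attention this is softmax(QK^T / sqrt(d)); for an SSM it is the
# materialized matrix of the linear recurrence over the sequence.
attn_mix = rng.random((L, L))
ssm_mix = rng.random((L, L))
stage1_loss = np.linalg.norm(attn_mix - ssm_mix, ord="fro") ** 2

# Stage 2: hidden-state alignment — match each block's output states
# between teacher and student with a mean-squared error.
teacher_hidden = rng.random((L, D))
student_hidden = rng.random((L, D))
stage2_loss = np.mean((teacher_hidden - student_hidden) ** 2)

# Stage 3: end-to-end distillation — match the final next-token
# distributions, here with a per-position KL divergence.
def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

p = softmax(rng.random((L, V)))  # teacher distribution
q = softmax(rng.random((L, V)))  # student distribution
stage3_loss = np.sum(p * (np.log(p) - np.log(q))) / L

print(stage1_loss, stage2_loss, stage3_loss)
```

In the paper's progressive scheme these losses are applied in order (coarse matrices first, predictions last), so the student inherits the teacher's sequence-mixing behavior before being tuned end to end.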
Problem

Research questions and friction points this paper is trying to address.

Distilling Transformers into subquadratic models
Matching mixing matrices and hidden units across architectures
Training strong SSMs with far less data
Innovation

Methods, ideas, or system contributions that make the work stand out.

MOHAWK distillation from Transformers to SSMs
Progressive matching: mixing matrices, hidden states, predictions
Strong results with only 3B–5B training tokens
Aviv Bick, Ph.D. in Computer Science at Carnegie Mellon University
Kevin Y. Li, Carnegie Mellon University
Eric P. Xing, Machine Learning Department, Carnegie Mellon University; MBZUAI
J. Zico Kolter, Machine Learning Department, Carnegie Mellon University
Albert Gu, Carnegie Mellon University