LoLCATs: On Low-Rank Linearizing of Large Language Models

📅 2024-10-14
🏛️ arXiv.org
📈 Citations: 4
Influential: 1
🤖 AI Summary
To address the high computational cost, degraded quality, and limited scale (previously ≤7B models) of linearizing large language models (LLMs), this paper proposes a two-stage low-rank linearization framework. First, attention transfer trains linear attentions to match the outputs of the original softmax attentions by minimizing a mean squared error (MSE) loss; second, low-rank adaptation (LoRA) fine-tunes the linearized model to recover quality. The approach is the first to linearize 70B- and 405B-scale LLMs, closing the linearization quality gap for Llama 3.1 70B and 405B by 77.8% and 78.1%, respectively, on 5-shot MMLU. On Llama 3 8B and Mistral 7B, it improves 5-shot MMLU over prior linearization methods by 20+ points while using only 0.2% of their model parameters and 0.4% of their training tokens. The key ideas are the combination of attention transfer with low-rank adaptation and the resulting subquadratic linear attention, which together remove the main bottlenecks in model scale and inference efficiency.
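The attention-transfer step described above reduces to a small distillation loop: compute the frozen softmax-attention outputs and train a learnable feature map so that linear attention reproduces them under an MSE loss. The sketch below illustrates the idea for a single non-causal head; the feature map, tensor shapes, and training loop are illustrative assumptions for this summary, not the paper's released code.

```python
# Minimal sketch of the attention-transfer idea (illustrative, not the paper's code).
import torch
import torch.nn.functional as F

def softmax_attention(q, k, v):
    # Standard scaled dot-product attention: cost grows as O(T^2) in sequence length T.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def linear_attention(q, k, v, feature_map):
    # Kernelized attention: phi(q) (phi(k)^T v) avoids forming the T x T score matrix.
    phi_q, phi_k = feature_map(q), feature_map(k)
    kv = phi_k.transpose(-2, -1) @ v                                   # (d_feat, d_v) summary
    z = phi_q @ phi_k.sum(dim=-2, keepdim=True).transpose(-2, -1)      # normalizer
    return (phi_q @ kv) / (z + 1e-6)

# Hypothetical learnable feature map; the ReLU keeps the normalizer nonnegative.
feature_map = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
opt = torch.optim.Adam(feature_map.parameters(), lr=1e-3)

q = torch.randn(2, 128, 64)   # (batch, seq_len, head_dim), random stand-ins for real activations
k = torch.randn(2, 128, 64)
v = torch.randn(2, 128, 64)

target = softmax_attention(q, k, v)            # frozen "teacher" outputs

for step in range(100):
    pred = linear_attention(q, k, v, feature_map)   # trainable "student"
    loss = F.mse_loss(pred, target)                 # attention-transfer loss
    opt.zero_grad()
    loss.backward()
    opt.step()
```

In the actual method this matching is done per layer and per head against the LLM's own queries, keys, and values; the toy loop here only shows the loss being minimized.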

📝 Abstract
Recent works show we can linearize large language models (LLMs) -- swapping the quadratic attentions of popular Transformer-based LLMs with subquadratic analogs, such as linear attention -- avoiding the expensive pretraining costs. However, linearizing LLMs often significantly degrades model quality, still requires training over billions of tokens, and remains limited to smaller 1.3B to 7B LLMs. We thus propose Low-rank Linear Conversion via Attention Transfer (LoLCATs), a simple two-step method that improves LLM linearizing quality with orders of magnitudes less memory and compute. We base these steps on two findings. First, we can replace an LLM's softmax attentions with closely-approximating linear attentions, simply by training the linear attentions to match their softmax counterparts with an output MSE loss ("attention transfer"). Then, this enables adjusting for approximation errors and recovering LLM quality simply with low-rank adaptation (LoRA). LoLCATs significantly improves linearizing quality, training efficiency, and scalability. We significantly reduce the linearizing quality gap and produce state-of-the-art subquadratic LLMs from Llama 3 8B and Mistral 7B v0.1, leading to 20+ points of improvement on 5-shot MMLU. Furthermore, LoLCATs does so with only 0.2% of past methods' model parameters and 0.4% of their training tokens. Finally, we apply LoLCATs to create the first linearized 70B and 405B LLMs (50x larger than prior work). When compared with prior approaches under the same compute budgets, LoLCATs significantly improves linearizing quality, closing the gap between linearized and original Llama 3.1 70B and 405B LLMs by 77.8% and 78.1% on 5-shot MMLU.
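To make the "subquadratic" claim concrete: causal linear attention can be evaluated as a recurrence over a constant-size state, so generation cost grows linearly with sequence length instead of quadratically. The sketch below is a generic kernelized-attention recurrence written for clarity, not LoLCATs' actual implementation, and it omits the paper's learned feature maps.

```python
# Generic causal linear attention in recurrent form (illustrative assumption, not the paper's code).
import torch

def causal_linear_attention(phi_q, phi_k, v):
    """Recurrent form: the state kv has shape (d_feat, d_v) and is updated once per token."""
    B, T, D = phi_q.shape
    d_v = v.shape[-1]
    kv = torch.zeros(B, D, d_v)          # running sum of phi(k_t) v_t^T
    z = torch.zeros(B, D)                # running sum of phi(k_t) for normalization
    out = []
    for t in range(T):                   # O(T) total work, O(1) state per step
        kv = kv + phi_k[:, t].unsqueeze(-1) * v[:, t].unsqueeze(-2)
        z = z + phi_k[:, t]
        num = (phi_q[:, t].unsqueeze(-2) @ kv).squeeze(-2)
        den = (phi_q[:, t] * z).sum(-1, keepdim=True) + 1e-6
        out.append(num / den)
    return torch.stack(out, dim=1)       # (B, T, d_v)

# Usage with a simple nonnegative feature map (ReLU) standing in for a learned one.
phi = torch.nn.functional.relu
q, k, v = (torch.randn(1, 256, 64) for _ in range(3))
y = causal_linear_attention(phi(q), phi(k), v)   # y: (1, 256, 64)
```

The practical consequence is that the recurrent state replaces the growing key-value cache of softmax attention, which is what makes linearized 70B and 405B models attractive for inference.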
Problem

Research questions and friction points this paper is trying to address.

Linearizing LLMs significantly degrades model quality.
Existing linearization methods still require training over billions of tokens, with high memory and compute costs.
Prior linearization is limited to smaller 1.3B–7B LLMs and does not scale to 70B or 405B parameters.
Innovation

Methods, ideas, or system contributions that make the work stand out.

Attention transfer: linear attentions are trained to match the LLM's softmax attentions with an output MSE loss before being swapped in.
Low-rank adaptation (LoRA) then corrects the remaining approximation error and recovers model quality (see the sketch after this list).
Requires only 0.2% of prior methods' model parameters and 0.4% of their training tokens.
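For reference, the second stage relies on standard low-rank adaptation: pretrained projections are frozen and only a small rank-r correction is trained. The sketch below shows a generic LoRA wrapper; the rank, scaling, and choice of wrapped layers are illustrative assumptions rather than the paper's settings.

```python
# Generic LoRA wrapper around a frozen linear layer (illustrative sketch).
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():        # freeze the pretrained weights
            p.requires_grad = False
        # Low-rank factors: B starts at zero so the adapter initially changes nothing.
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank

    def forward(self, x):
        # Frozen base projection plus a trainable low-rank correction.
        return self.base(x) + self.scale * (x @ self.lora_A.T @ self.lora_B.T)

# Example: wrap one (frozen) 512-dim projection and count trainable parameters.
proj = LoRALinear(nn.Linear(512, 512))
trainable = sum(p.numel() for p in proj.parameters() if p.requires_grad)
```

Because only the low-rank factors are updated, the trainable parameter count stays a small fraction of the full model, which is what keeps the second stage cheap even at 70B and 405B scale.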
🔎 Similar Papers
2024-05-17 · Conference on Empirical Methods in Natural Language Processing · Citations: 7
✍️ Authors

Michael Zhang (Department of Computer Science, Stanford University)
Simran Arora (Computer Science, Stanford University)
Rahul Chalamala (Together AI)
Alan Wu (California Institute of Technology)
Benjamin F. Spector (Department of Computer Science, Stanford University)
Aaryan Singhal (Department of Computer Science, Stanford University)
Krithik Ramesh (Together AI; Massachusetts Institute of Technology)
Christopher Ré (Department of Computer Science, Stanford University)