Simple linear attention language models balance the recall-throughput tradeoff

📅 2024-02-28
🏛️ arXiv.org
📈 Citations: 45
✨ Influential: 10
🤖 AI Summary
To improve the memory and computational efficiency of language models while preserving their ability to recall information seen earlier in context, this paper proposes BASED, an architecture that combines linear attention with sliding-window attention. The paper systematically characterizes the Pareto tradeoff between a model's recurrent state size and its recall ability. By varying the sliding-window size and the linear-attention feature dimension, BASED lets the state size be tuned along that frontier, trading recall capacity for memory. The authors also develop IO-aware algorithms that overcome the practical throughput bottleneck of linear attention implementations. Experiments show that a 1.3B-parameter BASED model matches the perplexity of the strongest sub-quadratic models (e.g., Mamba), outperforms Mamba by 6.22 accuracy points on real-world recall-intensive tasks, and achieves 24× higher generation throughput than FlashAttention-2 when generating 1024 tokens.
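To make the state-size tradeoff concrete, the following back-of-the-envelope sketch counts the floats a single layer must keep per sequence during decoding. The layer shape (16 heads of dimension 64), window sizes, and feature dimensions are hypothetical illustration values, not the paper's configurations. The linear-attention count assumes a second-order Taylor feature map with 1 + d' + d'^2 entries plus a normalizer vector, and the "hybrid" total simply adds one sliding-window and one linear-attention component; a real BASED model may distribute these across different layers.

```python
"""Per-layer decoding state, counted in floats for one sequence.

Hypothetical layer shapes; assumes a 2nd-order Taylor feature map for linear attention.
"""


def kv_cache_floats(seq_len: int, n_heads: int, head_dim: int) -> int:
    # Full softmax attention caches one key and one value per past token,
    # so its state grows linearly with the sequence length.
    return 2 * n_heads * head_dim * seq_len


def sliding_window_floats(window: int, n_heads: int, head_dim: int) -> int:
    # Sliding-window attention caches keys/values only for the last `window` tokens.
    return 2 * n_heads * head_dim * window


def taylor_linear_attention_floats(feature_dim: int, n_heads: int, head_dim: int) -> int:
    # Assumed feature map phi(x) = [1, x, vec(x x^T) / sqrt(2)] has 1 + d' + d'^2 entries.
    d_phi = 1 + feature_dim + feature_dim ** 2
    # Per head: running sum S = sum phi(k) v^T (d_phi x head_dim)
    # plus the normalizer z = sum phi(k) (d_phi entries).
    return n_heads * (d_phi * head_dim + d_phi)


def based_style_floats(window: int, feature_dim: int, n_heads: int, head_dim: int) -> int:
    # Hypothetical hybrid: one sliding-window component plus one linear-attention
    # component. Neither term depends on the sequence length.
    return (sliding_window_floats(window, n_heads, head_dim)
            + taylor_linear_attention_floats(feature_dim, n_heads, head_dim))


if __name__ == "__main__":
    n_heads, head_dim = 16, 64  # hypothetical layer shape
    for seq_len in (1024, 2048, 4096):
        print(f"attention, seq_len={seq_len}: {kv_cache_floats(seq_len, n_heads, head_dim):,} floats")
    # The two knobs named in the abstract: window size and feature dimension.
    for window, feature_dim in ((16, 8), (64, 16), (256, 24)):
        total = based_style_floats(window, feature_dim, n_heads, head_dim)
        print(f"hybrid, window={window}, d'={feature_dim}: {total:,} floats")
```

With these assumed shapes, a 2048-token KV cache holds 4,194,304 floats, while a window of 64 with d' = 16 holds 414,992 regardless of sequence length. Growing either knob enlarges the state (more room for recall), which is the dial the summary describes.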

πŸ“ Abstract
Recent work has shown that attention-based language models excel at recall, the ability to ground generations in tokens previously seen in context. However, the efficiency of attention-based models is bottlenecked during inference by the KV-cache's aggressive memory consumption. In this work, we explore whether we can improve language model efficiency (e.g., by reducing memory consumption) without compromising on recall. By applying experiments and theory to a broad set of architectures, we identify a key tradeoff between a model's state size and recall ability. We show that efficient alternatives to attention (e.g., H3, Mamba, RWKV) maintain a fixed-size recurrent state but struggle at recall. We propose BASED, a simple architecture combining linear and sliding window attention. By varying BASED's window size and linear attention feature dimension, we can dial the state size and traverse the Pareto frontier of the recall-memory tradeoff curve, recovering the full quality of attention on one end and the small state size of attention alternatives on the other. We train language models up to 1.3B parameters and show that BASED matches the strongest sub-quadratic models (e.g., Mamba) in perplexity and outperforms them on real-world recall-intensive tasks by 6.22 accuracy points. Implementations of linear attention are often less efficient than optimized standard attention implementations. To make BASED competitive, we develop IO-aware algorithms that enable 24x higher throughput on language generation than FlashAttention-2 when generating 1024 tokens using 1.3B-parameter models. Code for this work is provided at: https://github.com/HazyResearch/based.
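The abstract's point that attention alternatives keep a fixed-size recurrent state can be illustrated with causal linear attention, which can be computed either in a parallel form (quadratic in sequence length) or as a recurrence whose state never grows. The sketch below is a minimal pure-Python illustration, not the authors' IO-aware kernels: it assumes the standard normalized formulation and a second-order Taylor approximation of exp as the feature map, omits any 1/sqrt(d) scaling and learned projections, and all function and class names are hypothetical.

```python
import math


def taylor_feature_map(x):
    """Assumed 2nd-order Taylor feature map: phi(q) . phi(k) = 1 + q.k + (q.k)**2 / 2."""
    d = len(x)
    second_order = [x[i] * x[j] / math.sqrt(2) for i in range(d) for j in range(d)]
    return [1.0] + list(x) + second_order


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def causal_linear_attention_parallel(qs, ks, vs):
    """Quadratic-time form: y_t = sum_{s<=t} w_ts v_s / sum_{s<=t} w_ts, w_ts = phi(q_t).phi(k_s)."""
    d_v = len(vs[0])
    outputs = []
    for t, q in enumerate(qs):
        fq = taylor_feature_map(q)
        weights = [dot(fq, taylor_feature_map(ks[s])) for s in range(t + 1)]
        total = sum(weights)  # each weight is 1 + x + x^2/2 >= 0.5, so this never vanishes
        outputs.append([sum(w * vs[s][j] for s, w in enumerate(weights)) / total for j in range(d_v)])
    return outputs


class LinearAttentionState:
    """Recurrent form: a fixed-size state, independent of how many tokens have been seen."""

    def __init__(self, d_k, d_v):
        self.d_phi = 1 + d_k + d_k * d_k
        self.S = [[0.0] * d_v for _ in range(self.d_phi)]  # running sum of phi(k_s) v_s^T
        self.z = [0.0] * self.d_phi                         # running sum of phi(k_s)

    def step(self, q, k, v):
        # Fold the new key/value into the state, then read it out with the query.
        fk = taylor_feature_map(k)
        for i, fki in enumerate(fk):
            self.z[i] += fki
            for j, vj in enumerate(v):
                self.S[i][j] += fki * vj
        fq = taylor_feature_map(q)
        denom = dot(fq, self.z)
        return [sum(fq[i] * self.S[i][j] for i in range(self.d_phi)) / denom for j in range(len(v))]

    def num_floats(self):
        return self.d_phi * len(self.S[0]) + len(self.z)


if __name__ == "__main__":
    d_k, d_v = 4, 3
    for T in (8, 64):
        qs = [[0.5 * math.sin(1.3 * t + i) for i in range(d_k)] for t in range(T)]
        ks = [[0.5 * math.cos(0.7 * t + 2 * i) for i in range(d_k)] for t in range(T)]
        vs = [[math.sin(0.3 * t - j) for j in range(d_v)] for t in range(T)]
        state = LinearAttentionState(d_k, d_v)
        recurrent = [state.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
        parallel = causal_linear_attention_parallel(qs, ks, vs)
        err = max(abs(a - b) for ra, rb in zip(recurrent, parallel) for a, b in zip(ra, rb))
        print(T, state.num_floats(), err)  # state size does not change with T
```

Both forms agree up to floating-point rounding, but only the recurrent one has constant per-token memory and compute. The price is that everything the model can later recall must fit in `S` and `z`, which is the recall limitation the abstract attributes to fixed-state models.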
Problem

Research questions and friction points this paper is trying to address.

Balancing recall and throughput in language models
Reducing memory consumption without compromising recall
Improving the efficiency of linear attention implementations
Innovation

Methods, ideas, or system contributions that make the work stand out.

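Combines linear attention with sliding window attention
Tunes state size through window size and feature dimension to traverse the recall-memory Pareto frontier
IO-aware algorithms give 24x higher generation throughput than FlashAttention-2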
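The sliding-window half of the combination can be sketched as a decoder that keeps keys and values for only the most recent `window` tokens, so its cache stays bounded. With a window at least as long as the sequence, it reproduces full causal softmax attention, which is one way to see how a large enough state recovers full attention at one end of the tradeoff. This is a hypothetical single-head, pure-Python illustration; it does not show how BASED arranges this component alongside linear attention, and all names are illustrative.

```python
import math
from collections import deque


def softmax_attend(q, keys, values):
    """Exact softmax attention of one query over the given keys and values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d_v)]


class SlidingWindowAttention:
    """Hypothetical decoder-side sliding-window attention for a single head."""

    def __init__(self, window):
        # deque(maxlen=window) evicts the oldest entry once the window is full,
        # so the cache never holds more than `window` keys and values.
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)

    def step(self, q, k, v):
        # Cache the current token first so it can attend to itself (causal).
        self.keys.append(k)
        self.values.append(v)
        return softmax_attend(q, self.keys, self.values)


def full_causal_attention(qs, ks, vs):
    """Reference: every position attends to all earlier positions (unbounded cache)."""
    return [softmax_attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(len(qs))]


if __name__ == "__main__":
    T, d = 10, 4
    qs = [[math.sin(1.1 * t + i) for i in range(d)] for t in range(T)]
    ks = [[math.cos(0.9 * t - i) for i in range(d)] for t in range(T)]
    vs = [[math.sin(0.5 * t + 2 * i) for i in range(d)] for t in range(T)]
    ref = full_causal_attention(qs, ks, vs)
    for window in (3, T):
        swa = SlidingWindowAttention(window)
        outs = [swa.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
        err = max(abs(a - b) for ro, rr in zip(outs, ref) for a, b in zip(ro, rr))
        print(f"window={window}: cached tokens={len(swa.keys)}, max diff from full attention={err:.2e}")
```

A small window keeps exact attention over recent tokens at a fixed memory cost but forgets anything older, which is why pairing it with a global linear-attention state is attractive.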
πŸ”Ž Similar Papers
No similar papers found.