Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels

πŸ“… 2025-03-18
πŸ“ˆ Citations: 1
✨ Influential: 0
πŸ€– AI Summary
Linear RNNs such as the xLSTM scale linearly in sequence length, but existing chunkwise kernels (Flash Linear Attention) restrict the chunk size, forcing many intermediate states to be materialized in GPU memory; this yields low arithmetic intensity and high memory and I/O cost, especially in long-context pre-training. The authors propose Tiled Flash Linear Attention (TFLA), a kernel algorithm that enables arbitrarily large chunk sizes by adding a second level of sequence parallelization within each chunk. TFLA is first applied to the mLSTM, the xLSTM with matrix memory; the paper also introduces an mLSTM variant with a sigmoid input gate and reduced computation for even faster kernels at equal language modeling performance. In speed benchmarks, the TFLA-based mLSTM kernels outperform highly optimized Flash Attention, Linear Attention, and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.

πŸ“ Abstract
Linear RNNs with gating recently demonstrated competitive performance compared to Transformers in language modeling. Although their linear compute scaling in sequence length offers theoretical runtime advantages over Transformers, realizing these benefits in practice requires optimized custom kernels, as Transformers rely on the highly efficient Flash Attention kernels. Leveraging the chunkwise-parallel formulation of linear RNNs, Flash Linear Attention (FLA) shows that linear RNN kernels are faster than Flash Attention by parallelizing over chunks of the input sequence. However, since the chunk size of FLA is limited, many intermediate states must be materialized in GPU memory. This leads to low arithmetic intensity and causes high memory consumption and I/O cost, especially for long-context pre-training. In this work, we present Tiled Flash Linear Attention (TFLA), a novel kernel algorithm for linear RNNs that enables arbitrarily large chunk sizes by introducing an additional level of sequence parallelization within each chunk. First, we apply TFLA to the xLSTM with matrix memory, the mLSTM. Second, we propose an mLSTM variant with sigmoid input gate and reduced computation for even faster kernel runtimes at equal language modeling performance. In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized Flash Attention, Linear Attention and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.
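To make the chunkwise-parallel formulation mentioned in the abstract concrete, here is a minimal NumPy sketch of ungated linear attention in both its recurrent and chunkwise forms. This is an illustration of the general FLA-style formulation only, not the paper's gated mLSTM kernels; the function names are ours, and gating, normalization, and GPU-level details are deliberately omitted.

```python
import numpy as np

def linear_attention_recurrent(q, k, v):
    # O(T) recurrent form: a (d x d) state S accumulates outer products k_t v_t^T.
    T, d = q.shape
    S = np.zeros((d, d))
    out = np.zeros_like(v)
    for t in range(T):
        S += np.outer(k[t], v[t])      # update state with current key/value
        out[t] = q[t] @ S              # read out with current query
    return out

def linear_attention_chunkwise(q, k, v, chunk=4):
    # Chunkwise-parallel form: inter-chunk contributions come from the running
    # state S; intra-chunk contributions use a causally masked score matrix.
    T, d = q.shape
    S = np.zeros((d, d))
    out = np.zeros_like(v)
    for s in range(0, T, chunk):
        qc, kc, vc = q[s:s+chunk], k[s:s+chunk], v[s:s+chunk]
        inter = qc @ S                 # contribution of all previous chunks
        A = np.tril(qc @ kc.T)         # causal intra-chunk attention scores
        out[s:s+chunk] = inter + A @ vc
        S += kc.T @ vc                 # advance the state past this chunk
    return out
```

The two forms produce identical outputs; the chunkwise version trades the per-step recurrence for matrix multiplications over chunks, which is what makes FLA-style kernels fast on GPUs. The chunk size controls how many intermediate states `S` a kernel must hold, which is the bottleneck TFLA addresses.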
Problem

Research questions and friction points this paper is trying to address.

Optimize linear RNN kernels for efficient long-context sequence modeling.
Reduce memory consumption and IO costs in linear RNNs.
Enhance runtime performance of linear RNNs compared to Transformers.
Innovation

Methods, ideas, or system contributions that make the work stand out.

Tiled Flash Linear Attention for large chunks
Enhanced mLSTM with sigmoid input gate
Outperforms Flash Attention in speed benchmarks
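The core idea behind the first bullet, sequence parallelization within a chunk, can be sketched in NumPy: the intra-chunk computation is split into tiles over the key/value positions so that the full (chunk x chunk) score matrix is never materialized at once. This is only an illustration of the tiling principle under our own naming, not the paper's actual GPU kernel, which fuses this with gating and the inter-chunk state update.

```python
import numpy as np

def intra_chunk_tiled(qc, kc, vc, tile=2):
    # Compute tril(qc @ kc^T) @ vc tile by tile over key/value positions,
    # so only a (chunk x tile) score block exists at any time.
    c, d = qc.shape
    out = np.zeros_like(vc)
    for j in range(0, c, tile):
        kj, vj = kc[j:j+tile], vc[j:j+tile]
        A = qc @ kj.T                              # (c, tile) score tile
        rows = np.arange(c)[:, None]               # global query indices
        cols = np.arange(j, min(j + tile, c))[None, :]  # global key indices
        out += np.where(rows >= cols, A, 0.0) @ vj # causal mask per tile
    return out
```

Summing the masked tile contributions reproduces the full causally masked intra-chunk output, which is why larger chunk sizes become feasible: the tile size, not the chunk size, bounds the working set.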