🤖 AI Summary
Distributed attention faces severe communication bottlenecks when scaling the context window of large language models; existing Ring-Attention suffers from high communication overhead, limiting its scalability. This paper proposes Mesh-Attention, the first two-dimensional tiled distributed attention paradigm, which maps attention computation onto matrix-shaped tiles across a GPU mesh and subsumes Ring-Attention as a special case. We design a greedy scheduling algorithm under communication constraints and theoretically prove its asymptotically reduced communication complexity. Experiments on 256 GPUs show that Mesh-Attention achieves an average 2.9× speedup over Ring-Attention (up to 3.4×) and reduces communication volume by 79.0% on average (up to 85.4%). The method significantly improves both scalability and efficiency for long-context LLM inference and training.
📝 Abstract
Distributed attention is a fundamental problem in scaling the context window of Large Language Models (LLMs). The state-of-the-art method, Ring-Attention, suffers from scalability limitations due to its excessive communication traffic. This paper proposes a new distributed attention algorithm, Mesh-Attention, by rethinking the design space of distributed attention with a new matrix-based model. Our method assigns a two-dimensional tile -- rather than a one-dimensional row or column -- of computation blocks to each GPU, achieving higher efficiency through a lower communication-computation (CommCom) ratio. The general approach covers Ring-Attention as a special case and allows the CommCom ratio to be tuned through different tile shapes. Importantly, we propose a greedy algorithm that efficiently searches the scheduling space within a tile, subject to restrictions that ensure efficient communication among GPUs. The theoretical analysis shows that Mesh-Attention has much lower communication complexity and scales better than other current algorithms.
Our extensive experimental results show that Mesh-Attention achieves up to 3.4× speedup (2.9× on average) over Ring-Attention and reduces communication volume by up to 85.4% (79.0% on average) on 256 GPUs. Our scalability results further demonstrate that Mesh-Attention sustains superior performance as the system scales, substantially reducing overhead in large-scale deployments. The results convincingly confirm the advantage of Mesh-Attention.
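The tile-shape trade-off described in the abstract can be illustrated with a small toy model. The sketch below is not the paper's cost model: the function names, the per-block cost weights, and the assumption that each GPU already holds one Q block and one KV block are all hypothetical, chosen only to show why a squarer tile may exchange fewer blocks than a one-dimensional row for the same amount of computation.

```python
def tile_shapes(n_gpus):
    """Return all (rows, cols) tile shapes with rows * cols == n_gpus."""
    return [(a, n_gpus // a) for a in range(1, n_gpus + 1) if n_gpus % a == 0]


def comm_units(rows, cols, q_cost=1.0, kv_cost=2.0, out_cost=1.0):
    """Toy count of blocks one GPU exchanges for a rows x cols tile.

    Assumptions (illustrative, not taken from the paper):
    - the GPU already holds one Q block and one KV block locally;
    - each extra tile row costs one Q fetch plus one partial output to send;
    - each extra tile column costs one KV fetch (K and V, hence weight 2).
    """
    return (rows - 1) * (q_cost + out_cost) + (cols - 1) * kv_cost


def commcom_ratio(rows, cols):
    """Communication units divided by the computation blocks in the tile."""
    return comm_units(rows, cols) / (rows * cols)


if __name__ == "__main__":
    n = 256  # every tile below covers the same number of computation blocks
    for rows, cols in tile_shapes(n):
        print(f"{rows:>3} x {cols:<3} "
              f"comm={comm_units(rows, cols):6.1f} "
              f"ratio={commcom_ratio(rows, cols):.3f}")
```

In this toy model the 1 x 256 tile plays the role of the one-dimensional (Ring-Attention-like) assignment, while tiles closer to square trade a few extra Q fetches and output transfers for far fewer KV fetches. The numbers it prints depend entirely on the assumed weights and should not be read as a reproduction of the reported measurements.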