🤖 AI Summary
Transformer self-attention incurs high inter-device communication overhead and scales poorly for long sequences, limiting distributed training efficiency. To address this, the authors propose ATTENTION2D, an exact 2D-parallel self-attention method that uses no approximations and adds no computational or memory overhead. The method partitions the attention computation along both the query (Q) and key/value (KV) dimensions, so the work can be distributed and parallelized across many devices, and it continues to scale effectively as processing units are added, in both training and inference. Compared to Ring Attention, it achieves up to a 5× end-to-end speedup on a GPT-3-like model using 64 NVIDIA A100 GPUs across 16 nodes, and up to a 9.4× speedup on 64 NVIDIA H100 GPUs across 64 nodes, while significantly improving communication efficiency.
📝 Abstract
Transformer-based models have emerged as a leading architecture for natural language processing, natural language generation, and image generation tasks. A fundamental element of the transformer architecture is self-attention, which allows the model to capture intricate dependencies within the data. However, the self-attention mechanism also incurs significant computational and memory costs, particularly for long sequences. In this paper, we introduce ATTENTION2D, a novel approach that exploits parallelism along two dimensions of the self-attention operation: query and key/value. This method enables efficient distribution and parallelization of computations across multiple devices. Our approach enables asymptotically faster training and inference compared to previous methods, without relying on approximations or incurring additional computational or memory overheads. Furthermore, unlike existing techniques that struggle to scale with an increasing number of processing units, our approach scales effectively with additional processing units. Our experimental results confirm the effectiveness of our method in improving communication efficiency and scalability. Compared to Ring Attention, our approach demonstrated up to a 5x performance boost on a GPT-3-like model using 64 NVIDIA A100 GPUs across 16 nodes, and up to a 9.4x performance boost on 64 NVIDIA H100 GPUs across 64 nodes.
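To make the two-dimensional partitioning concrete, here is a minimal single-process sketch of the idea in NumPy. It is not the paper's implementation: it simulates a hypothetical `pq × pkv` device grid in plain loops, where rows of the grid hold blocks of Q and columns hold blocks of K/V, and partial results along the KV dimension are merged with the standard online-softmax combining rule (running max, normalizer, and weighted accumulator). All function and variable names are illustrative assumptions.

```python
import numpy as np

def attention_reference(Q, K, V):
    # Standard single-device attention: softmax(Q K^T / sqrt(d)) V.
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def attention_2d(Q, K, V, pq=2, pkv=2):
    # Simulate a pq x pkv grid: grid rows partition Q, grid columns
    # partition K and V. "Device" (i, j) sees only (Q_i, K_j, V_j).
    d = Q.shape[-1]
    Q_blocks = np.split(Q, pq)
    K_blocks = np.split(K, pkv)
    V_blocks = np.split(V, pkv)
    out_rows = []
    for Qi in Q_blocks:
        # Running softmax statistics for this Q block: row-wise max m,
        # normalizer l, and un-normalized weighted sum acc.
        m = np.full((Qi.shape[0], 1), -np.inf)
        l = np.zeros((Qi.shape[0], 1))
        acc = np.zeros((Qi.shape[0], V.shape[-1]))
        for Kj, Vj in zip(K_blocks, V_blocks):
            # Local partial scores on device (i, j); in a real system this
            # loop body runs in parallel and the merge below is a reduction
            # across the KV dimension of the grid.
            S = Qi @ Kj.T / np.sqrt(d)
            m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
            P = np.exp(S - m_new)
            scale = np.exp(m - m_new)  # rescale old stats to the new max
            l = l * scale + P.sum(axis=-1, keepdims=True)
            acc = acc * scale + P @ Vj
            m = m_new
        out_rows.append(acc / l)  # normalize once all KV blocks are merged
    return np.concatenate(out_rows)
```

Because the online-softmax merge is exact, the block-partitioned result matches the reference computation up to floating-point rounding, which is the "no approximation" property the abstract claims; the actual system additionally needs communication primitives to gather K/V blocks and reduce partial statistics across devices, which this sketch elides.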