🤖 AI Summary
Problem: In long-context large language model training, the core attention computation, softmax(QKᵀ)V, is colocated with the model's other layers. Because its compute grows quadratically with context length while the other components grow near-linearly, long contexts cause severe load imbalance and straggler-induced tail latency across data- and pipeline-parallel groups.
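To see why, a back-of-the-envelope comparison helps (illustrative constants, not from the paper): with hidden size h, core attention costs roughly 4L²h FLOPs at context length L, while the surrounding MLP costs roughly 16Lh², so attention's share of compute grows linearly in L.

```python
# Illustrative FLOP-scaling sketch; the hidden size and cost formulas are
# assumed round numbers, not values from the paper.
h = 8192  # hypothetical model hidden size

def attention_flops(L: int) -> int:
    # softmax(QK^T)V: two L x L matmuls over the hidden dimension ~ 4 * L^2 * h
    return 4 * L**2 * h

def mlp_flops(L: int) -> int:
    # two linear layers with 4x expansion ~ 16 * L * h^2
    return 16 * L * h**2

for L in (8_192, 131_072, 524_288):  # 8K, 128K, 512K tokens
    print(f"L={L:>7,}: attention/MLP FLOP ratio = {attention_flops(L) / mlp_flops(L):.2f}")
# Prints 0.25, 4.00, 16.00: the ratio equals L / (4h), so at 512K tokens core
# attention dominates step time and its imbalance becomes the bottleneck.
```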
Method: We propose Core Attention Disaggregation (CAD), which abstracts the parameter-free, stateless core attention computation as schedulable tasks and offloads them to a dedicated pool of GPU attention servers. CAD combines fine-grained token-level task partitioning, attention splitting, ping-pong execution, and in-place computation to fully overlap communication with computation and to reduce memory use.
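The property that makes this offloading possible is that core attention is stateless: it reads only activations (Q, K, V) and touches no trainable weights, so any device holding those tensors can execute a shard of it. Below is a minimal sketch of one such task, assuming hypothetical shard shapes and omitting the causal masking and fused kernels (e.g., FlashAttention) a real system would use.

```python
import torch

def core_attention_task(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Parameter-free core attention on one token-level shard (sketch).

    A hypothetical CAD 'task': it involves no model weights, so it can run on
    any device in the attention-server pool. Shapes are assumed as
    q: [Tq, H, D], k/v: [Tk, H, D] for Tq query tokens attending over
    Tk key/value tokens. Causal masking is omitted for brevity.
    """
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale  # [H, Tq, Tk]
    probs = scores.softmax(dim=-1)                       # attention weights
    return torch.einsum("hqk,khd->qhd", probs, v)        # [Tq, H, D]
```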
Results: Evaluated on 512 H200 GPUs with context lengths up to 512K tokens, CAD improves end-to-end training throughput by up to 1.35× over the baseline while achieving near-perfect compute and memory balance across devices.
📝 Abstract
We present core attention disaggregation (CAD), a technique that improves long-context large language model training by decoupling the core attention computation, softmax(QKᵀ)V, from the rest of the model and executing it on a separate pool of devices. In existing systems, core attention is colocated with other layers; at long context lengths, its quadratic compute growth, against the near-linear growth of other components, causes load imbalance and stragglers across data- and pipeline-parallel groups. CAD is enabled by two observations. First, core attention is stateless: it has no trainable parameters and only minimal transient data, so balancing reduces to scheduling compute-bound tasks. Second, it is composable: modern attention kernels retain high efficiency when processing fused batches of token-level shards with arbitrary lengths. CAD partitions core attention into token-level tasks and dispatches them to dedicated attention servers, which dynamically rebatch tasks to equalize compute without sacrificing kernel efficiency. We implement CAD in a system called DistCA, which uses a ping-pong execution scheme to fully overlap communication with computation and in-place execution on attention servers to reduce memory use. On 512 H200 GPUs and context lengths up to 512K tokens, DistCA improves end-to-end training throughput by up to 1.35×, eliminates data- and pipeline-parallel stragglers, and achieves near-perfect compute and memory balance.
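The composability observation is what makes dynamic rebatching viable: since attention kernels stay efficient on fused batches of variable-length shards, a server can pack an arbitrary mix of tasks. The sketch below shows one plausible compute-equalizing dispatch policy, a greedy longest-task-first heuristic; the paper's actual scheduler may differ, and `task_costs` is an assumed per-task FLOP estimate (roughly q_len × kv_len per shard).

```python
import heapq

def rebatch_tasks(task_costs: list[float], num_servers: int) -> list[list[int]]:
    """Greedy sketch of compute-equalizing dispatch (our reading of CAD's
    rebatching, not the paper's scheduler): assign each token-level core
    attention task, largest first, to the currently least-loaded server.

    task_costs: estimated FLOPs per task.
    Returns one list of task indices per attention server.
    """
    heap = [(0.0, s) for s in range(num_servers)]  # (accumulated load, server id)
    heapq.heapify(heap)
    batches = [[] for _ in range(num_servers)]
    for i in sorted(range(len(task_costs)), key=lambda i: -task_costs[i]):
        load, s = heapq.heappop(heap)
        batches[s].append(i)
        heapq.heappush(heap, (load + task_costs[i], s))
    return batches
```

Under causal attention, shards late in the sequence attend over far more keys than early ones, so per-task costs vary widely; a longest-first greedy assignment keeps per-server totals close to the mean, which is the kind of compute balance the abstract reports.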