AI Summary
Under tensor parallelism (TP), Multi-Head Latent Attention (MLA) loses much of its memory advantage because every device must load the full latent key-value (KV) cache, so the cache is replicated across devices and MLA's edge over Grouped Query Attention (GQA) erodes. This work proposes Tensor-Parallel Latent Attention (TPLA), which partitions both MLA's latent representation and each head's input dimension across devices, runs attention independently on each shard, and combines the partial results with an all-reduce. Each device therefore stores only its slice of the compressed KV cache, while every head still leverages the full latent representation. Applying an orthogonal transform (Hadamard or PCA) before slicing mitigates cross-shard interference. TPLA requires no retraining and is a drop-in replacement for pre-trained MLA models: it keeps MLA-style prefilling and enables efficient tensor-parallel decoding. At a 32K-token context length, TPLA achieves 1.79× and 1.93× speedups for DeepSeek-V3 and Kimi-K2, respectively, while maintaining performance on commonsense and LongBench benchmarks.
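The memory argument can be made concrete with a little arithmetic. The sketch below compares how much latent KV cache one device holds under plain TP over MLA (full cache on every device) versus TPLA-style slicing (one shard per device). The helper `per_device_kv_cache_bytes` is hypothetical, the configuration (32K tokens, 60 layers, a 512-dimensional cached vector, bf16 storage, TP degree 8) is an illustrative assumption rather than any specific model, and the sketch assumes the whole cached vector splits evenly across devices, ignoring how TPLA treats MLA's decoupled RoPE key.

```python
def per_device_kv_cache_bytes(seq_len: int, num_layers: int, cached_dim: int,
                              tp_degree: int, shard_latent: bool,
                              bytes_per_elem: int = 2) -> int:
    """Bytes of latent KV cache one device holds for one sequence (hypothetical helper).

    shard_latent=False models plain TP over MLA: every device keeps the full
    cached vector for every token. shard_latent=True models TPLA-style slicing,
    assuming the cached vector splits evenly across the tp_degree devices.
    """
    if shard_latent:
        if cached_dim % tp_degree != 0:
            raise ValueError("cached_dim must be divisible by tp_degree")
        dims_per_device = cached_dim // tp_degree
    else:
        dims_per_device = cached_dim
    return seq_len * num_layers * dims_per_device * bytes_per_elem


# Illustrative (assumed) configuration: 32K tokens, 60 layers,
# 512-dim cached latent, bf16 (2 bytes/element), TP degree 8.
full = per_device_kv_cache_bytes(32 * 1024, 60, 512, 8, shard_latent=False)
sharded = per_device_kv_cache_bytes(32 * 1024, 60, 512, 8, shard_latent=True)
print(f"plain TP:   {full / 2**20:.0f} MiB per device")     # 1920 MiB
print(f"TPLA-style: {sharded / 2**20:.0f} MiB per device")  # 240 MiB
```

Under these assumptions the per-device cache, and the cache bytes read per decoding step, shrink by the TP degree. This is an idealized storage ratio only; it is not a prediction of the reported end-to-end speedups.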
Abstract
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms (e.g., the Hadamard transform or PCA) before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
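A minimal NumPy sketch can make the shard-then-all-reduce idea concrete for a single head in MLA's absorbed form, where the query is already projected into the latent space and the value is recovered from the attended latent through an up-projection. This is an illustration under assumptions, not the authors' implementation: the names `mla_head`, `tpla_head`, and `hadamard` are hypothetical, the decoupled RoPE component is omitted, every shard reuses the same softmax scale, and a plain sum stands in for the cross-device all-reduce. The code checks two properties that follow from the algebra: with one shard the sketch reduces exactly to MLA, and an orthonormal Hadamard rotation of the latent (folded into the query and value projections) leaves full MLA unchanged while changing how information is split across shards.

```python
import numpy as np

def hadamard(n: int) -> np.ndarray:
    """Orthonormal Sylvester-Hadamard matrix; n must be a power of two."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())  # subtract max for numerical stability
    return e / e.sum()

def mla_head(q_lat, C, W_uv, scale):
    """Absorbed MLA, one head: q_lat (d_c,), latent cache C (T, d_c), W_uv (d_v, d_c)."""
    a = softmax(scale * (C @ q_lat))  # attention weights over the T cached tokens
    return W_uv @ (a @ C)             # attend in latent space, then up-project to d_v

def tpla_head(q_lat, C, W_uv, scale, tp):
    """Hypothetical TPLA-style sketch: split the latent dim into `tp` shards,
    run attention independently per shard, then sum the partial outputs."""
    d_c = C.shape[1]
    assert d_c % tp == 0, "latent dim must divide evenly across shards"
    s = d_c // tp
    partials = []
    for i in range(tp):
        sl = slice(i * s, (i + 1) * s)                   # this "device's" latent slice
        a_i = softmax(scale * (C[:, sl] @ q_lat[sl]))    # shard-local scores and softmax
        partials.append(W_uv[:, sl] @ (a_i @ C[:, sl]))  # shard-local partial output
    return np.sum(partials, axis=0)                      # stands in for the all-reduce

rng = np.random.default_rng(0)
T, d_c, d_v, tp = 16, 64, 32, 4
scale = 1.0 / np.sqrt(d_c)  # assumed scale for this toy example
q_lat = rng.standard_normal(d_c)
C = rng.standard_normal((T, d_c))
W_uv = rng.standard_normal((d_v, d_c))
ref = mla_head(q_lat, C, W_uv, scale)

# Rotate each cached latent c -> H c and the absorbed query q -> H q;
# fold H^T into the value projection so that full MLA is unchanged.
H = hadamard(d_c)
q_rot, C_rot, W_uv_rot = H @ q_lat, C @ H.T, W_uv @ H.T

assert np.allclose(tpla_head(q_lat, C, W_uv, scale, tp=1), ref)  # one shard == MLA
assert np.allclose(mla_head(q_rot, C_rot, W_uv_rot, scale), ref)  # rotation-invariant

def rel_err(x, y):
    return np.linalg.norm(x - y) / np.linalg.norm(y)

print("sharded, no rotation:", rel_err(tpla_head(q_lat, C, W_uv, scale, tp), ref))
print("sharded, Hadamard:   ", rel_err(tpla_head(q_rot, C_rot, W_uv_rot, scale, tp), ref))
```

With i.i.d. Gaussian toy data the rotation should not change the expected sharding error, because such data is already statistically rotation-invariant; the abstract's claim about orthogonal transforms concerns the latents of real pre-trained models.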