TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill & Decode Inference

2025-08-21
AI Summary
Under tensor parallelism (TP), Multi-Head Latent Attention (MLA) loses much of its memory advantage because each device must load the full key-value (KV) cache. This work proposes TPLA, which partitions both MLA's latent representation and each head's input dimension across devices, runs attention independently on each shard, and combines the results with an all-reduce, so every head still draws on the full latent representation. Applying an orthogonal transform (Hadamard or PCA) before TP slicing mitigates cross-shard interference. TPLA requires no retraining and is drop-in compatible with pre-trained MLA models. At a 32K-token context length, reducing the per-device KV cache yields speedups of 1.79× for DeepSeek-V3 and 1.93× for Kimi-K2, while maintaining performance on commonsense and LongBench benchmarks.

Abstract
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
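
The partition, attend, and all-reduce steps described in the abstract can be sketched in NumPy for a single head on a single host. This is only an illustrative reading of the abstract, not the paper's implementation: the function names, the explicit `scale` argument, the output projection `w_o`, and the use of a plain sum to stand in for the all-reduce are all assumptions made here.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def latent_attention(q, c, scale):
    """Single-head attention in latent space (hypothetical helper).

    q: (d,) query already mapped into the latent space.
    c: (T, d) cached latent vectors, used as both keys and values.
    """
    weights = softmax((c @ q) * scale)  # (T,)
    return weights @ c                  # (d,)

def sharded_latent_attention(q, c, w_o, num_shards, scale):
    """Simulate TP shards that each hold one slice of the latent dimension.

    w_o: (d, d_out) output projection, split by rows so that each shard
    projects only its own slice (an assumption of this sketch).
    """
    q_parts = np.split(q, num_shards)
    c_parts = np.split(c, num_shards, axis=1)
    w_parts = np.split(w_o, num_shards, axis=0)
    partials = [
        latent_attention(q_s, c_s, scale) @ w_s
        for q_s, c_s, w_s in zip(q_parts, c_parts, w_parts)
    ]
    # Summing the per-shard outputs stands in for the all-reduce.
    return np.sum(partials, axis=0)
```

With `num_shards=1` the sketch reduces to ordinary latent attention followed by the output projection. With more shards, each shard normalizes its own softmax over a partial dot product, so the result generally differs from the unsharded output; that gap is one plausible reading of the "cross-shard interference" the abstract mentions.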
Problem

Research questions and friction points this paper is trying to address.

Reducing KV cache memory in tensor parallelism without performance loss
Enabling efficient tensor-parallel decoding for MLA-pretrained models
Maintaining full representational capacity per head while partitioning the latent representation across devices
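
To make the memory concern in the list above concrete, the following back-of-the-envelope helper compares the per-device latent cache when every device holds the full latent vector against the case where the latent dimension is split evenly across TP ranks. The function name and all parameters are hypothetical, and the sketch counts only the latent vector, ignoring any other cached components.

```python
def latent_cache_bytes_per_device(seq_len, num_layers, latent_dim,
                                  tp_degree=1, shard_latent=False,
                                  bytes_per_elem=2):
    """Rough per-device size of the latent KV cache (illustrative only)."""
    if shard_latent:
        if latent_dim % tp_degree != 0:
            raise ValueError("latent_dim must be divisible by tp_degree")
        per_token = latent_dim // tp_degree  # each rank keeps one slice
    else:
        per_token = latent_dim               # every rank keeps the full latent
    return seq_len * num_layers * per_token * bytes_per_elem
```

Under these assumptions, splitting the latent dimension across `tp_degree` ranks divides the per-device latent cache by `tp_degree`, whereas the replicated case does not shrink as devices are added.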
Innovation

Methods, ideas, or system contributions that make the work stand out.

Partitions latent representation across tensor parallel devices
Performs independent attention per shard with all-reduce combination
Uses orthogonal transforms to minimize cross-shard interference
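
One property that makes an orthogonal transform safe to apply before slicing is that it leaves full query-key dot products unchanged while redistributing energy across coordinates. The sketch below checks this for a normalized Hadamard matrix built with Sylvester's construction; the helper names are hypothetical, and how the paper folds the transform into the model weights is not shown here.

```python
import numpy as np

def hadamard(n):
    """Orthonormal n x n Hadamard matrix; n must be a power of two."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = np.array([[1.0]])
    while h.shape[0] < n:
        h = np.block([[h, h], [h, -h]])  # Sylvester doubling step
    return h / np.sqrt(n)

def rotate_latents(q, c, h):
    """Apply the same orthogonal transform to the query and the cached latents."""
    return q @ h, c @ h

rng = np.random.default_rng(0)
q = rng.standard_normal(8)          # latent-space query
c = rng.standard_normal((5, 8))     # five cached latent vectors
q_rot, c_rot = rotate_latents(q, c, hadamard(8))

# The full attention scores are identical before and after the rotation.
scores_match = np.allclose(c @ q, c_rot @ q_rot)
```

Because the transform is orthogonal, `scores_match` holds for any `q` and `c`; only the per-slice partial scores change, which is where the choice of transform can matter once the latent dimension is split.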