TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill & Decode Inference

📅 2025-08-21
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
Under tensor parallelism (TP), Multi-Head Latent Attention (MLA) loses much of its memory advantage because every device must load the full key-value (KV) latent cache, causing redundant cache replication and inter-device communication bottlenecks. This work proposes TPLA, which shards both MLA's latent representation and each head's input dimension across devices, enabling a distributed compressed KV cache and parallel attention computation while preserving full representational capacity: unlike Grouped Latent Attention (GLA), every head still attends over the full latent. TPLA combines low-rank compression, orthogonal transformations (Hadamard or PCA) that mitigate cross-shard interference, and a single all-reduce to merge per-shard outputs. It requires no retraining and is plug-and-play for models pre-trained with MLA. At a 32K-token context length, TPLA achieves 1.79× and 1.93× speedups for DeepSeek-V3 and Kimi-K2, respectively, while maintaining accuracy on commonsense and LongBench benchmarks.
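To make the per-device memory argument concrete, a back-of-the-envelope sketch follows. The latent rank 512 and decoupled RoPE key dimension 64 follow DeepSeek-V3's published MLA configuration; the layer count, fp16 cache, TP degree, and the assumption that TPLA splits only the latent rank evenly across ranks are illustrative choices here, not numbers taken from the paper.

```python
# Rough per-device KV-cache size: MLA replicated under TP vs. a latent-sharded (TPLA-style) cache.
# Latent/RoPE dims follow DeepSeek-V3's MLA config; the even latent split is an illustrative assumption.
LATENT_RANK = 512      # compressed KV latent dimension (kv_lora_rank)
ROPE_DIM = 64          # decoupled RoPE key dimension cached alongside the latent
LAYERS = 61
BYTES = 2              # fp16/bf16
TP = 8                 # tensor-parallel degree
SEQ = 32_768           # 32K-token context

def cache_gib(dims_per_token: int) -> float:
    return SEQ * LAYERS * dims_per_token * BYTES / 2**30

mla_per_device = cache_gib(LATENT_RANK + ROPE_DIM)          # full latent cache on every device
tpla_per_device = cache_gib(LATENT_RANK // TP + ROPE_DIM)   # latent rank sharded across devices

print(f"MLA  under TP={TP}: {mla_per_device:.2f} GiB per device")
print(f"TPLA under TP={TP}: {tpla_per_device:.2f} GiB per device")
```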

πŸ“ Abstract
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.
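For intuition, below is a minimal single-process sketch of the decode-time computation the abstract describes: the latent cache is split into TP shards, every head attends over each shard's slice independently, and the partial outputs are summed, which in a real deployment is the all-reduce across devices. All shapes, the absorbed per-shard projection tensors, and the plain-sum combination are illustrative assumptions rather than the paper's exact formulation.

```python
import torch

# Illustrative sizes (not the paper's): model dim, latent rank, heads, head dim, TP degree.
D, R, H, DH, TP = 1024, 256, 16, 64, 4
T = 128                      # cached sequence length
r = R // TP                  # latent shard size per device

q = torch.randn(1, H, DH)                    # query for one new token, all heads
latent = torch.randn(T, R)                   # compressed KV latent cache
w_k = torch.randn(TP, r, H, DH) / r**0.5     # hypothetical per-shard absorbed key projections
w_v = torch.randn(TP, r, H, DH) / r**0.5     # hypothetical per-shard absorbed value projections

# Each "device" sees only its slice of the latent and produces a partial output.
partials = []
for dev in range(TP):
    c = latent[:, dev * r:(dev + 1) * r]                 # (T, r) latent shard
    k = torch.einsum('tr,rhd->thd', c, w_k[dev])         # (T, H, DH) shard keys
    v = torch.einsum('tr,rhd->thd', c, w_v[dev])         # (T, H, DH) shard values
    scores = torch.einsum('bhd,thd->bht', q, k) / DH**0.5
    attn = scores.softmax(dim=-1)                        # softmax taken independently per shard
    partials.append(torch.einsum('bht,thd->bhd', attn, v))

# In a real TP deployment this sum is performed by an all-reduce across devices.
out = torch.stack(partials).sum(dim=0)
print(out.shape)  # (1, H, DH)
```

Because each shard runs its own softmax, the combined result is an approximation of full attention; this is the cross-shard interference that the orthogonal transforms below are meant to mitigate.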
Problem

Research questions and friction points this paper is trying to address.

Reducing KV cache memory in tensor parallelism without performance loss
Enabling efficient tensor-parallel decoding for MLA-pretrained models
Maintaining full representational capacity while partitioning attention heads
Innovation

Methods, ideas, or system contributions that make the work stand out.

Partitions both the latent representation and each head's input dimension across tensor-parallel devices
Performs attention independently per shard and combines the partial outputs with an all-reduce
Applies orthogonal transforms (Hadamard or PCA) before TP slicing to minimize cross-shard interference, as sketched below
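Below is a minimal sketch of the third point: rotating the latent dimension with an orthogonal matrix (a normalized Hadamard matrix here) and folding the rotation into the surrounding projections, so that TP slicing happens in a basis where energy is spread more evenly across shards. The weight names and shapes are hypothetical; only the rotate-then-slice idea comes from the paper.

```python
import torch
from scipy.linalg import hadamard

# Rotate the latent dimension with an orthogonal matrix before slicing it across TP ranks.
# Folding the rotation into the projections leaves the pre-slicing model unchanged;
# names (w_down, w_up) and the normalized Hadamard choice are illustrative assumptions.
R, TP = 512, 8
Q = torch.tensor(hadamard(R), dtype=torch.float32) / R**0.5   # orthogonal: Q @ Q.T = I

w_down = torch.randn(1024, R)    # projects hidden states into the latent (cached) space
w_up = torch.randn(R, 1024)      # absorbed projection out of the latent space

w_down_rot = w_down @ Q          # latent is now produced in the rotated basis
w_up_rot = Q.T @ w_up            # and consumed from it: w_down_rot @ w_up_rot == w_down @ w_up (up to float error)

# TP slicing happens in the rotated basis, where per-shard energy is more uniform.
w_up_shards = w_up_rot.chunk(TP, dim=0)   # each rank keeps R // TP latent channels
```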
Authors
Xiaojuan Tang
School of Intelligence Science and Technology, Peking University
machine learning, reasoning
Fanxu Meng
Institute for Artificial Intelligence, Peking University
Pingzhi Tang
Institute for Artificial Intelligence, Peking University
Yuxuan Wang
Institute for Artificial Intelligence, Peking University
Di Yin
Tencent
LLM, NLP, MLLM
Xing Sun
Tencent Youtu Lab
LLM, MLLM, Agent
Muhan Zhang
Peking University
Machine Learning, Graph Neural Networks, Large Language Models