🤖 AI Summary
This work identifies a long-overlooked memory bottleneck: the QKV linear projections in the multi-head attention of large language models (LLMs) incur substantial, underestimated GPU memory overhead during training. To address this, the authors propose Point-Approximate Matrix Multiplication (PAMM), a lightweight, modular, low-rank tensor compression technique that is fully compatible with efficient attention implementations such as FlashAttention. PAMM reduces the GPU memory footprint of the QKV projections by a factor of up to 512, effectively eliminating their memory cost. Evaluated on Llama-2 and OPT models, PAMM delivers significant training-memory savings while matching, and often improving, final perplexity. The core contributions are threefold: (i) the first systematic quantification of the hidden memory bottleneck imposed by the QKV projections; (ii) a QKV-specific compression method that reaches extreme compression ratios (up to 512×) without degrading accuracy or throughput; and (iii) a plug-and-play design that integrates seamlessly into existing attention frameworks.
📝 Abstract
The Multi-Head Attention mechanism is central to LLM operation, and many works target its compute and memory efficiency during training. While most focus on approximating the scaled dot product, the memory consumption of the linear projections that compute the $Q$, $K$, and $V$ tensors from the input $x$ is often overlooked. To address this, we propose Point-Approximate Matrix Multiplication (PAMM), a novel tensor compression technique that reduces the memory consumption of the $Q$, $K$, $V$ projections in attention layers by a factor of up to $512\times$, effectively erasing their memory footprint, while achieving similar or better final perplexity. PAMM is fully composable with efficient attention techniques such as FlashAttention, making it a practical and complementary method for memory-efficient LLM training.
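The abstract does not spell out PAMM's internal mechanism, so the following is only an illustrative back-of-the-envelope sketch of what a 512× reduction of the QKV projection memory would mean under a generic low-rank factorization $W \approx UV$ (the factorization, the rank $r = 6$, and the Llama-2-7B-style hidden size $d = 4096$ are assumptions for the arithmetic, not PAMM's actual method):

```python
# Illustrative arithmetic only: a generic low-rank stand-in, not PAMM itself.
d_model = 4096            # hidden size (Llama-2-7B scale, assumed)
d_qkv = 3 * d_model       # Q, K, V projections fused into one output dim

# Dense fused QKV projection weight: d_model x 3*d_model entries.
dense_params = d_model * d_qkv

# Hypothetical rank-r factorization W ~ U @ V with U: (d_model, r), V: (r, d_qkv).
r = 6                     # rank chosen so the ratio comes out to 512x for this shape
lowrank_params = d_model * r + r * d_qkv

ratio = dense_params / lowrank_params
print(f"dense:    {dense_params:,} params")
print(f"low-rank: {lowrank_params:,} params")
print(f"ratio:    {ratio:.0f}x")   # 512x at this shape and rank
```

The point of the sketch is the scaling: a rank-$r$ factorization stores $r(d + 3d)$ entries instead of $3d^2$, so the compression ratio grows as $3d/4r$ and single-digit ranks already suffice for three-digit ratios at LLM hidden sizes.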