🤖 AI Summary
This work addresses a limitation of conventional Transformer multi-head attention: the absence of explicit inter-head interaction, which constrains model expressiveness and training efficiency. The authors propose Multi-head Explicit Attention (MEA), a mechanism that explicitly models cross-head interaction through head-wise linear combinations and head-level group normalization. Additionally, MEA incorporates low-rank virtual heads to compress key-value (KV) cache memory. The method improves training stability and convergence speed, achieving lower validation loss. It maintains near-lossless performance on knowledge-intensive and scientific reasoning tasks, with only a 3.59% accuracy drop on olympiad-level mathematical reasoning, while reducing KV cache memory overhead by 50%.
📝 Abstract
In large language models built upon the Transformer architecture, recent studies have shown that inter-head interaction can enhance attention performance. Motivated by this, we propose Multi-head Explicit Attention (MEA), a simple yet effective attention variant that explicitly models cross-head interaction. MEA consists of two key components: a Head-level Linear Composition (HLC) module that separately applies learnable linear combinations to the key and value vectors across heads, thereby enabling rich inter-head communication; and a head-level Group Normalization layer that aligns the statistical properties of the recombined heads. MEA shows strong robustness in pretraining, which allows the use of larger learning rates that lead to faster convergence, ultimately resulting in lower validation loss and improved performance across a range of tasks. Furthermore, we explore the parameter efficiency of MEA by reducing the number of attention heads and leveraging HLC to reconstruct them using low-rank "virtual heads". This enables a practical key-value cache compression strategy that reduces KV-cache memory usage by 50% with negligible performance loss on knowledge-intensive and scientific reasoning tasks, and only a 3.59% accuracy drop for Olympiad-level mathematical benchmarks.
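The abstract's two ingredients can be sketched concretely: learnable linear combinations along the head axis (HLC) to recombine key/value heads, followed by a per-head GroupNorm to align their statistics. The PyTorch sketch below is an illustrative assumption, not the authors' released code; the class names, tensor shapes, and the choice of `nn.GroupNorm` with one group per head are all hypothetical readings of the abstract. Setting the number of stored heads to half the number of reconstructed "virtual heads" mirrors the 50% KV-cache reduction described above.

```python
import torch
import torch.nn as nn

class HeadLinearComposition(nn.Module):
    """Sketch of the Head-level Linear Composition (HLC) idea:
    recombine per-head key/value tensors with a learnable mixing
    matrix over the head axis. Shapes and names are assumptions."""

    def __init__(self, n_in_heads: int, n_out_heads: int):
        super().__init__()
        # Learnable head-mixing weights; with n_out_heads > n_in_heads
        # this reconstructs "virtual heads" from a smaller stored set.
        self.mix = nn.Linear(n_in_heads, n_out_heads, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_in_heads, seq_len, head_dim)
        x = x.permute(0, 2, 3, 1)        # (B, T, D, H_in)
        x = self.mix(x)                  # mix along heads -> (B, T, D, H_out)
        return x.permute(0, 3, 1, 2)     # (B, H_out, T, D)

B, H_in, H_out, T, D = 2, 4, 8, 16, 32
hlc = HeadLinearComposition(H_in, H_out)

# Only H_in heads need to live in the KV cache; attention then runs
# over H_out reconstructed heads. Storing 4 of 8 heads halves the cache.
k_cached = torch.randn(B, H_in, T, D)
k_virtual = hlc(k_cached)                # (B, H_out, T, D)

# Head-level normalization to align statistics of recombined heads:
# one group per head, normalizing each head over (seq, dim). This is
# one plausible reading of the paper's "head-level Group Normalization".
head_norm = nn.GroupNorm(num_groups=H_out, num_channels=H_out)
k_norm = head_norm(k_virtual)            # same shape as k_virtual
print(k_norm.shape)                      # torch.Size([2, 8, 16, 32])
```

The same HLC module would be applied separately to keys and values, each with its own mixing matrix, per the abstract's description.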