🤖 AI Summary
This work addresses the tendency of standard Softmax attention to disperse probability mass across irrelevant tokens in long-context scenarios, which impairs focus. While lowering the temperature sharpens attention, it often leads to vanishing gradients and hampers trainability. To overcome this trade-off, the authors propose LUCID attention, a novel mechanism that introduces a preconditioner based on exponentiated key-key similarities to minimize overlap among keys in a reproducing kernel Hilbert space (RKHS). This reshapes the attention distribution to enhance focus without suffering from the learning instability induced by low temperatures. LUCID maintains the same computational complexity as standard attention and, when applied to a 1-billion-parameter model trained with 128K context length, achieves up to 18% and 14% absolute improvements in multi-needle retrieval performance on BABILong and RULER benchmarks, respectively.
📝 Abstract
Softmax-based dot-product attention is a cornerstone of Transformer architectures, enabling remarkable capabilities such as in-context learning. However, as context lengths increase, a fundamental limitation of the softmax function emerges: it tends to diffuse probability mass onto irrelevant tokens, degrading performance in long-sequence scenarios. Furthermore, attempts to sharpen focus by lowering the softmax temperature hinder learnability due to vanishing gradients. We introduce LUCID Attention, an architectural modification that applies a preconditioner to the attention probabilities. This preconditioner, derived from exponentiated key-key similarities, minimizes overlap between the keys in a Reproducing Kernel Hilbert Space, allowing the query to focus accurately on important keys among a large number of keys, with the same computational complexity as standard attention. Additionally, LUCID's preconditioning-based approach to retrieval bypasses the need for a low temperature and the learnability problems associated with it. We validate our approach by training ~1-billion-parameter language models evaluated on contexts of up to 128K tokens. Our results demonstrate significant gains on long-context retrieval tasks, specifically on BABILong, RULER, SCROLLS, and LongBench. For instance, LUCID achieves up to 18% improvement in BABILong and 14% improvement in RULER multi-needle performance compared to standard attention.
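The abstract does not give LUCID's exact formula, but the described mechanism — standard attention probabilities preconditioned by a matrix built from exponentiated key-key similarities — can be sketched as follows. This is a minimal, hypothetical NumPy reading of that description, not the authors' implementation: the choice of kernel scaling, the use of the inverse Gram matrix as the preconditioner, and the renormalization step are all assumptions.

```python
import numpy as np

def standard_attention_probs(Q, K):
    """Standard scaled dot-product softmax attention probabilities."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

def lucid_attention_probs_sketch(Q, K, eps=1e-12):
    """Hypothetical sketch of LUCID-style preconditioning.

    Builds an exponentiated key-key Gram matrix G (an exponential-kernel
    Gram matrix, hence a valid RKHS kernel matrix) and applies G^{-1} to
    each row of the softmax probabilities, which counteracts probability
    mass leaking onto keys that are similar to the relevant one.
    The exact preconditioner in the paper may differ.
    """
    d = K.shape[-1]
    P = standard_attention_probs(Q, K)
    G = np.exp(K @ K.T / np.sqrt(d))       # exponentiated key-key similarities
    # Solve G X = P^T, i.e. compute P @ G^{-1} (G is symmetric PD).
    P_pre = np.linalg.solve(G, P.T).T
    P_pre = np.clip(P_pre, 0.0, None)      # keep probabilities nonnegative
    return P_pre / (P_pre.sum(axis=-1, keepdims=True) + eps)
```

Note that forming and applying `G` costs the same O(n^2 d) as the attention scores themselves, consistent with the claim that LUCID matches standard attention's computational complexity (a practical implementation would avoid the explicit O(n^3) solve).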