🤖 AI Summary
In conventional causal attention, each token's query, key, and value (QKV) projections are static: once computed, a token's key encodes only the context that precedes it, which limits the modeling of long-range dependencies. This paper proposes CASTLE, an autoregressive attention mechanism that updates the keys of earlier positions as each new token is decoded, so that they incorporate information from tokens arriving later, while strictly preserving the causal constraint. Its core technical contribution is a mathematical equivalence that avoids explicitly materializing these lookahead keys at every position, yielding a differentiable, parallelizable training procedure whose computational complexity matches that of standard attention. Experiments show that CASTLE consistently reduces validation perplexity on multiple language modeling benchmarks, outperforms standard causal attention baselines across parameter scales (64M–1.3B), and improves performance on downstream tasks.
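To make the idea of keys that keep changing during decoding concrete, here is a minimal NumPy sketch of sequential decoding with lookahead keys. The update rule is a hypothetical stand-in chosen for illustration, not the rule from the paper: at step t, the key of each earlier position s becomes k_s + u_s, where u_s sums sigmoid-gated vectors c_j over the tokens j with s < j <= t. The projection names (Wq, Wk, Wv, Wa, Wb, Wc) and the helpers make_params, lookahead_keys, and decode are likewise assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def make_params(d, seed=0):
    # Hypothetical projections: Wq, Wk, Wv for standard attention,
    # Wa, Wb, Wc for the toy lookahead-key update path.
    rng = np.random.default_rng(seed)
    return {name: rng.standard_normal((d, d)) / np.sqrt(d)
            for name in ("Wq", "Wk", "Wv", "Wa", "Wb", "Wc")}

def lookahead_keys(X, p, t):
    """Toy lookahead keys for positions 0..t at decoding step t.

    u_s = sum over j in (s, t] of sigmoid(a_s . b_j / sqrt(d)) * c_j,
    computed from rows 0..t of X only (hypothetical rule).
    """
    d = X.shape[1]
    Xt = X[: t + 1]                        # only the tokens seen so far
    A, B, C = Xt @ p["Wa"], Xt @ p["Wb"], Xt @ p["Wc"]
    U = np.zeros_like(Xt)
    for s in range(t + 1):
        for j in range(s + 1, t + 1):      # later tokens, never beyond t
            U[s] += sigmoid(A[s] @ B[j] / np.sqrt(d)) * C[j]
    return U

def decode(X, p):
    """Sequential decoding: step t attends over keys k_s + u_s^(t) for s <= t."""
    L, d = X.shape
    outputs = np.zeros_like(X)
    for t in range(L):
        Xt = X[: t + 1]
        q = X[t] @ p["Wq"]
        K = Xt @ p["Wk"] + lookahead_keys(X, p, t)   # keys refreshed every step
        V = Xt @ p["Wv"]
        scores = K @ q / np.sqrt(d)
        w = np.exp(scores - scores.max())            # numerically stable softmax
        outputs[t] = (w / w.sum()) @ V
    return outputs
```

Because lookahead_keys only reads rows 0..t of X, perturbing later tokens cannot change the outputs at earlier steps, even though the key of position 0 differs at every step. That is the sense in which such keys "look ahead" relative to their own position while the model stays autoregressive.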
📝 Abstract
In standard causal attention, each token's query, key, and value (QKV) are static and encode only preceding context. We introduce CAuSal aTtention with Lookahead kEys (CASTLE), an attention mechanism that continually updates each token's keys as the context unfolds. We term these updated keys lookahead keys because they belong to earlier positions yet integrate information from tokens that appear later relative to those positions, while strictly preserving the autoregressive property. Although the mechanism appears sequential, we derive a mathematical equivalence that avoids explicitly materializing lookahead keys at each position and enables efficient parallel training. On language modeling benchmarks, CASTLE consistently outperforms standard causal attention across model scales, reducing validation perplexity and improving performance on a range of downstream tasks.
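The abstract's point that the lookahead keys need not be materialized at each position can be illustrated on a toy rule. The sketch below assumes, hypothetically and without following the paper's exact formulation, that the lookahead key of position s at step t is u_s^(t) = sum over j in (s, t] of sigmoid(a_s . b_j / sqrt(d)) c_j. It compares two ways of computing the score contributions q_t . u_s^(t): one that builds every u_s^(t) explicitly, and one that uses two masked matrices and a single matrix product. The function names are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lookahead_scores_sequential(Q, A, B, C):
    """S[t, s] = q_t . u_s^(t) for s <= t, materializing u_s^(t) at every step t."""
    L, d = Q.shape
    S = np.zeros((L, L))
    for t in range(L):
        for s in range(t + 1):
            u = np.zeros(d)
            for j in range(s + 1, t + 1):
                u += sigmoid(A[s] @ B[j] / np.sqrt(d)) * C[j]
            S[t, s] = Q[t] @ u
    return S

def lookahead_scores_parallel(Q, A, B, C):
    """Same scores without forming any u_s^(t).

    q_t . u_s^(t) = sum over s < j <= t of G[s, j] * M[t, j], where
    G[s, j] = sigmoid(a_s . b_j / sqrt(d)) kept only for j > s, and
    M[t, j] = q_t . c_j kept only for j <= t.
    """
    L, d = Q.shape
    G = np.triu(sigmoid(A @ B.T / np.sqrt(d)), k=1)   # strictly upper: j > s
    M = np.tril(Q @ C.T)                               # lower incl. diagonal: j <= t
    # Entries with s >= t have an empty j-range and come out zero automatically.
    return M @ G.T
```

In this toy form the product M @ G.T still costs on the order of L^3 operations, so it only demonstrates that explicit materialization is unnecessary; it is not a claim about how the paper achieves its efficient parallel training.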