🤖 AI Summary
This work addresses the computational inefficiency of standard attention in long-context inference, where the cost of each new token grows linearly with context length because every cached key-value pair must be read. The authors propose Nectar, a lightweight neural approach that replaces attention over a fixed context's cache with a forward pass whose cost does not depend on cache size. The method fits two compact networks per layer and KV-head: a target network that predicts the attention output and a score network that predicts the log-normalizer. A non-uniform allocation of capacity across transformer layers further improves approximation fidelity. Experiments on models from 1.7B to 8B parameters and five long-context datasets show that the approximation error tracks the next-token accuracy gap to full attention, and that generated answers match in semantic content those produced with access to the full cache.
📝 Abstract
Evaluating softmax attention over a fixed long context of $n$ tokens requires reading every cached key-value pair for each new query token. For a given context (a book, a manual, a legal corpus), the attention output is a deterministic function of the query. We propose Nectar, which fits a compact neural network to this function for queries drawn from a task-relevant distribution. Nectar fits two networks per layer and KV-head: a target network that predicts the attention output and a score network that predicts the log-normalizer. The pair plugs into standard masked self-attention at inference time, replacing the $O(n)$ attention over the cache with a forward pass whose cost does not depend on $n$. Each module carries on the order of $|\theta|$ parameters per layer and KV-head, typically much smaller than the $2nd$ KV-cache footprint at the same granularity. We report experiments on models from 1.7B to 8B parameters across five long-context datasets. The approximation error tracks the next-token accuracy gap to full attention, and allocating capacity non-uniformly across layers reduces that gap in our ablation. Beyond these metrics, we check that text generated in response to a question prompt by a model equipped with a Nectar module matches in semantic content the text generated by the same model with access to the full cache.
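The abstract does not spell out how the two predictions plug into attention. One plausible reading is the standard log-sum-exp merge of attention partials computed over disjoint key sets, which the NumPy sketch below illustrates. This is an assumption on our part, not the authors' implementation. The function names (`attention_with_lse`, `merge`), the shapes, and the $1/\sqrt{d}$ score scaling are hypothetical, and the exact context-side values stand in for what a trained target network and score network would predict.

```python
import numpy as np

def attention_with_lse(q, K, V):
    """Exact softmax attention for one query, plus its log-normalizer."""
    scores = K @ q / np.sqrt(q.shape[-1])   # (n,) scaled dot products
    m = scores.max()                        # shift for numerical stability
    w = np.exp(scores - m)
    lse = m + np.log(w.sum())               # log sum_j exp(score_j)
    out = (w / w.sum()) @ V                 # (d,) attention output
    return out, lse

def merge(out_a, lse_a, out_b, lse_b):
    """Combine two attention partials computed over disjoint key sets."""
    lse = np.logaddexp(lse_a, lse_b)        # log of the joint normalizer
    out = np.exp(lse_a - lse) * out_a + np.exp(lse_b - lse) * out_b
    return out, lse

rng = np.random.default_rng(0)
n_ctx, n_new, d = 512, 8, 16                # hypothetical sizes
K_ctx, V_ctx = rng.normal(size=(n_ctx, d)), rng.normal(size=(n_ctx, d))
K_new, V_new = rng.normal(size=(n_new, d)), rng.normal(size=(n_new, d))
q = rng.normal(size=d)

# Assumption: these exact values stand in for the target network's output
# and the score network's log-normalizer for the fixed context.
ctx_out, ctx_lse = attention_with_lse(q, K_ctx, V_ctx)

# Attention over the few new (non-context) tokens is computed exactly.
new_out, new_lse = attention_with_lse(q, K_new, V_new)

merged, _ = merge(ctx_out, ctx_lse, new_out, new_lse)
full, _ = attention_with_lse(q, np.vstack([K_ctx, K_new]),
                             np.vstack([V_ctx, V_new]))
max_abs_err = float(np.abs(merged - full).max())
```

With exact context-side inputs, the merged result agrees with full attention over the concatenated keys up to floating-point error. Under this reading, any error in a real deployment would come from the two networks' predictions, not from the merge step.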