🤖 AI Summary
This paper addresses the quantification of approximation error in Top-k sparse attention mechanisms. We propose the first unified framework that jointly models error at both the distributional level (via the total-variation distance TV(P, P̂)) and the output level. Our key contributions are: (i) a precise characterization linking TV distance to softmax tail mass, yielding a sharp, non-asymptotic, Top-k-specific error bound; (ii) an output error decomposition driven by the difference between the head and tail value means; and (iii) tighter bounds incorporating multiple gaps and block-wise variants. The analysis assumes i.i.d. Gaussian logits and leverages closed-form tail approximations. Empirical validation on BERT-base-uncased and synthetic data confirms the predicted k_ε/n scaling law: for a given TV error budget ε, Top-k sparsification reduces key-score computations by 2–4× on average, enabling certifiably accurate sparse attention.
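The two levels of error described above can be checked numerically. The sketch below (our illustration, not the paper's code) builds a random attention distribution, forms its renormalized Top-k truncation, and verifies that the TV distance equals the discarded softmax tail mass, that it satisfies the TV-KL identity from the abstract, and that the output error factorizes through the head-tail mean difference:

```python
# Numerical sketch (assumed setup, not the paper's implementation):
#   (1) TV(P, P_hat) = discarded softmax tail mass = 1 - exp(-KL(P_hat || P));
#   (2) ||Attn - Attn_k||_2 = tau * ||mu_tail - mu_head||_2.
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 64, 8, 16

s = rng.normal(0.0, 1.0, n)            # logits
V = rng.normal(size=(n, d))            # value vectors
p = np.exp(s - s.max()); p /= p.sum()  # full softmax distribution P

head = np.argsort(s)[-k:]              # indices of the Top-k logits
tail = np.setdiff1d(np.arange(n), head)
tau = p[tail].sum()                    # discarded softmax tail mass

p_hat = np.zeros(n)
p_hat[head] = p[head] / p[head].sum()  # renormalized Top-k truncation P_hat

tv = 0.5 * np.abs(p - p_hat).sum()
kl = np.sum(p_hat[head] * np.log(p_hat[head] / p[head]))
assert np.isclose(tv, tau)
assert np.isclose(tv, 1.0 - np.exp(-kl))

# Head-tail factorization of the output error
mu_head = p[head] @ V[head] / p[head].sum()
mu_tail = p[tail] @ V[tail] / p[tail].sum()
err = np.linalg.norm(p @ V - p_hat @ V)
assert np.isclose(err, tau * np.linalg.norm(mu_tail - mu_head))
```

Both identities are exact (not bounds), which is why the assertions hold for any logits and values, not just this random draw.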
📝 Abstract
We develop a unified mathematical framework for certified Top-$k$ attention truncation that quantifies approximation error at both the distribution and output levels. For a single attention distribution $P$ and its Top-$k$ truncation $\hat P$, we show that the total-variation distance coincides with the discarded softmax tail mass and satisfies $\mathrm{TV}(P,\hat P)=1-e^{-\mathrm{KL}(\hat P\Vert P)}$, yielding sharp Top-$k$-specific bounds in place of generic inequalities. From this we derive non-asymptotic deterministic bounds -- from a single boundary gap through multi-gap and blockwise variants -- that control $\mathrm{TV}(P,\hat P)$ using only the ordered logits. Using an exact head-tail decomposition, we prove that the output error factorizes as $\|\mathrm{Attn}(q,K,V)-\mathrm{Attn}_k(q,K,V)\|_2=\tau\,\|\mu_{\mathrm{tail}}-\mu_{\mathrm{head}}\|_2$ with $\tau=\mathrm{TV}(P,\hat P)$, yielding a new head-tail diameter bound $\|\mathrm{Attn}(q,K,V)-\mathrm{Attn}_k(q,K,V)\|_2\le\tau\,\mathrm{diam}_{H,T}$ and refinements linking the error to $\mathrm{Var}_P(V)$. Under an i.i.d. Gaussian score model $s_i\sim\mathcal N(\mu,\sigma^2)$ we derive closed-form tail masses and an asymptotic rule for the minimal $k_\varepsilon$ ensuring $\mathrm{TV}(P,\hat P)\le\varepsilon$, namely $k_\varepsilon/n\approx\Phi_c(\sigma+\Phi^{-1}(\varepsilon))$. Experiments on bert-base-uncased and synthetic logits confirm the predicted scaling of $k_\varepsilon/n$ and show that certified Top-$k$ can reduce scored keys by 2--4$\times$ on average while meeting the prescribed total-variation budget.
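The asymptotic rule for the minimal $k_\varepsilon$ can also be probed empirically. The sketch below (our illustration, under the abstract's i.i.d. Gaussian score model; the helper name and sample size are our choices) draws a large batch of Gaussian logits, finds the smallest $k$ whose discarded softmax tail mass is at most $\varepsilon$, and compares the resulting fraction $k_\varepsilon/n$ with the closed-form prediction $\Phi_c(\sigma+\Phi^{-1}(\varepsilon))$:

```python
# Hedged sketch: Monte-Carlo check of k_eps/n ≈ Phi_c(sigma + Phi^{-1}(eps))
# for i.i.d. Gaussian logits s_i ~ N(0, sigma^2). k_eps is the smallest k
# whose discarded softmax tail mass (= TV(P, P_hat)) is at most eps.
import numpy as np
from statistics import NormalDist

def k_eps_over_n(sigma: float, eps: float, n: int = 200_000, seed: int = 0) -> float:
    """Empirical minimal fraction k/n whose softmax tail mass is <= eps."""
    s = np.sort(np.random.default_rng(seed).normal(0.0, sigma, n))[::-1]
    p = np.exp(s - s[0]); p /= p.sum()   # softmax, sorted descending
    head_mass = np.cumsum(p)             # mass kept by the Top-k prefix
    k = np.searchsorted(head_mass, 1.0 - eps) + 1
    return k / n

nd = NormalDist()
sigma, eps = 2.0, 0.05
predicted = 1.0 - nd.cdf(sigma + nd.inv_cdf(eps))  # Phi_c(sigma + Phi^{-1}(eps))
empirical = k_eps_over_n(sigma, eps)
print(predicted, empirical)  # the two fractions should roughly agree
```

Larger $\sigma$ concentrates the softmax on fewer keys, so both the predicted and empirical fractions shrink as the logit scale grows, which is the source of the reported key-scoring savings.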