DualKV: Shared-Prompt Flash Attention for Efficient RL Training with Large Rollouts and Long Contexts

📅 2026-05-14
📈 Citations: 0
Influential: 0

🤖 AI Summary
This work addresses a major source of redundancy in standard FlashAttention during large-rollout, long-context reinforcement learning (RL) training: the shared prompt is recomputed and stored once per response sequence, and this duplication dominates the cost of the policy update. The authors propose DualKV, described as the first approach to exploit, inside the FlashAttention kernel, the fact that causal masking makes prompt representations identical across sequences. By combining fused CUDA forward/backward kernels that attend over two KV regions (the shared prompt and each sequence's own response) with veRL-aware data repacking, DualKV processes each shared prompt once, without approximation and with exact equivalence to standard attention. For N responses of R tokens sharing a P-token prompt, the repacking shrinks the token count seen by the entire model by a factor ρ = N(P+R)/(P+NR). On Qwen3-8B, this yields 1.63–2.09× faster policy updates, 2× larger micro-batches, and an increase in model FLOPs utilization (MFU) from 36% to 76%. On a 30B MoE model, it achieves a 3.82× policy-update speedup and a 3.38× end-to-end step speedup.
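The token arithmetic behind the repacking can be made concrete. The sketch below is a minimal illustration, not the paper's veRL code: `packed_layout` and `token_reduction` are hypothetical helpers, and the choice that every response reuses position ids P..P+R-1 (the ids it would have in its unpacked [prompt; response] row) is our assumption about how positions are preserved. It lays out one shared prompt plus N responses as P + NR tokens and computes ρ = N(P+R)/(P+NR), the ratio of unpacked to packed tokens.

```python
from typing import List, Tuple


def packed_layout(n: int, p: int, r: int) -> Tuple[List[int], List[int]]:
    """Hypothetical repacked micro-batch: one shared prompt, then n responses.

    Returns (segment_ids, position_ids) for p + n * r tokens. Segment -1 marks
    the shared prompt; segment i in [0, n) marks response i. Each response
    reuses position ids p .. p + r - 1, as in its unpacked [prompt; response] row.
    """
    segment_ids = [-1] * p
    position_ids = list(range(p))
    for i in range(n):
        segment_ids += [i] * r
        position_ids += list(range(p, p + r))
    return segment_ids, position_ids


def token_reduction(n: int, p: int, r: int) -> float:
    """rho = N(P+R) / (P+NR): unpacked token count divided by packed token count."""
    return n * (p + r) / (p + n * r)


if __name__ == "__main__":
    seg, pos = packed_layout(n=4, p=3, r=2)
    print(len(seg))  # 11 = 3 + 4 * 2
    print(pos)       # [0, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4]
    print(token_reduction(n=32, p=8192, r=1024))  # 7.2 (R = 1024 is a hypothetical value)
```

With N = 32, P = 8192, and an illustrative R = 1024, the packed micro-batch holds 7.2× fewer tokens. Because NP ≥ P and N(P+R) ≤ N(P+NR), ρ always lies between 1 and N: it approaches N when prompts dominate (P ≫ R) and falls toward 1 when responses dominate.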
📝 Abstract
Modern RL post-training methods such as GRPO and DAPO train on $N$ response sequences of $R$ tokens sampled from a shared prompt of $P$ tokens, but standard FlashAttention replicates all $P$ prompt tokens $N$ times across both forward and backward passes, duplicating compute and memory on identical hidden states. In large-rollout, long-context RL training ($N{\geq}16$, $P{\geq}8\text{K}$), this redundancy dominates the policy update cost. We observe that in decoder-only models, causal masking makes prompt representations invariant across sequences at every layer, so all per-token operations (norms, projections, MLP) and attention can process the prompt once, a property not yet exploited at the kernel level for training. We propose \textbf{DualKV}, the first FlashAttention kernel variant that eliminates shared-prompt replication during RL training, via (1)~fused CUDA forward and backward kernels that iterate over two disjoint KV regions (shared context and per-sequence response) in a single kernel launch, and (2)~a data-pipeline redesign in veRL that repacks $N(P{+}R)$ tokens into $P{+}NR$ tokens per micro-batch, extending the token reduction from attention to the entire model by a factor $\rho = N(P{+}R)/(P{+}NR)$. DualKV is mathematically equivalent to standard attention and introduces no approximation. On Qwen3-8B GRPO training with 8$\times$H100 GPUs ($N{=}32$, 8K-context), DualKV achieves $1.63$--$2.09\times$ policy-update speedup, enables $2\times$ larger micro-batches, and raises MFU from $36\%$ to $76\%$. Similar gains hold for DAPO ($2.47\times$ speedup, $77\%$ MFU). At 30B MoE scale on 16$\times$H100, DualKV achieves $3.82\times$ policy-update and $3.38\times$ end-to-end step speedup over FlashAttention (which requires 4-way Ulysses sequence parallelism to avoid OOM).
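The exactness claim rests on a simple observation: under causal masking, a query in response i can see every prompt key plus its own response keys up to its position, and nothing else. The pure-Python sketch below is our illustration, not the authors' CUDA kernel, and every function name in it is hypothetical. It streams those two disjoint KV regions through a single online-softmax accumulator, in the style of a FlashAttention inner loop, and checks the result against ordinary causal attention over each unpacked [prompt; response] row. The prompt K/V are generated once and reused by every sequence, which models the invariance the paper relies on.

```python
import math
import random
from typing import List, Sequence

Vec = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def reference_causal_attention(q: Vec, keys: List[Vec], values: List[Vec], t: int) -> Vec:
    """Plain softmax attention for the query at position t over keys[0..t]."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys[: t + 1]]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    d = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values[: t + 1])) / z for j in range(d)]


def dual_region_attention(q: Vec, shared_k: List[Vec], shared_v: List[Vec],
                          resp_k: List[Vec], resp_v: List[Vec], t: int) -> Vec:
    """Query at response offset t, streamed over two disjoint KV regions.

    Region 1: all shared-prompt keys (no mask: every prompt token precedes the response).
    Region 2: this sequence's own response keys 0..t (causal mask).
    A single online-softmax accumulator spans both regions.
    """
    scale = 1.0 / math.sqrt(len(q))
    d = len(shared_v[0]) if shared_v else len(resp_v[0])
    m, z, acc = -math.inf, 0.0, [0.0] * d  # running max, denominator, unnormalized output
    for keys, vals in ((shared_k, shared_v), (resp_k[: t + 1], resp_v[: t + 1])):
        for k, v in zip(keys, vals):
            s = dot(q, k) * scale
            m_new = max(m, s)
            rescale = math.exp(m - m_new)  # shrink earlier partial sums; exp(-inf) == 0.0
            w = math.exp(s - m_new)
            z = z * rescale + w
            acc = [a * rescale + w * vj for a, vj in zip(acc, v)]
            m = m_new
    return [a / z for a in acc]


def max_error(n: int, p: int, r: int, d: int, seed: int = 0) -> float:
    """Largest |dual-region - reference| over all response queries of n sequences."""
    rng = random.Random(seed)

    def rand_vecs(count: int) -> List[Vec]:
        return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(count)]

    # Prompt K/V are built once and shared by all n sequences (the causal-invariance assumption).
    shared_k, shared_v = rand_vecs(p), rand_vecs(p)
    worst = 0.0
    for _ in range(n):
        resp_q, resp_k, resp_v = rand_vecs(r), rand_vecs(r), rand_vecs(r)
        full_k, full_v = shared_k + resp_k, shared_v + resp_v  # unpacked [prompt; response] row
        for t in range(r):
            fast = dual_region_attention(resp_q[t], shared_k, shared_v, resp_k, resp_v, t)
            ref = reference_causal_attention(resp_q[t], full_k, full_v, p + t)
            worst = max(worst, max(abs(a - b) for a, b in zip(fast, ref)))
    return worst


if __name__ == "__main__":
    print(f"max |dual - reference| = {max_error(n=3, p=5, r=4, d=8):.2e}")
```

The online-softmax rescaling is exact, so the two paths differ only by floating-point rounding. The sketch covers only response queries in the forward pass; prompt queries attend to prompt keys alone and are therefore computed once. The saving comes from storing the prompt K/V once and running the prompt tokens through every layer once, while each response's attention to the prompt still has to be computed per sequence.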
Problem

Research questions and friction points this paper is trying to address.

Reinforcement Learning
Large Rollouts
Long Contexts
Shared Prompt
FlashAttention
Innovation

Methods, ideas, or system contributions that make the work stand out.

DualKV
FlashAttention
shared-prompt optimization
efficient RL training
long-context attention