🤖 AI Summary
This work addresses the performance bottleneck in deterministic attention mechanisms for large language models, where serial gradient accumulation during backpropagation leads to reduced throughput and suboptimal hardware utilization. The authors formulate deterministic attention backpropagation as a directed acyclic graph (DAG) scheduling problem and propose two complementary strategies: a descending Q-tile traversal order to minimize pipeline stalls in causal attention, and a theoretically optimal Shift Scheduling policy applicable to both full and causally masked attention. Integrated with the FlashAttention-3 architecture, these techniques achieve up to a 1.28× speedup in backward pass throughput on NVIDIA H800 GPUs, substantially narrowing the performance gap between deterministic and non-deterministic implementations.
📝 Abstract
Determinism is indispensable for reproducibility in large language model (LLM) training, yet it often exacts a steep performance cost. In widely used attention implementations such as FlashAttention-3, the deterministic backward pass can incur up to a 37.9% throughput reduction relative to its non-deterministic counterpart, primarily because gradient accumulation operations must be serialized to guarantee numerical consistency. This performance loss stems from suboptimal scheduling of compute and gradient-reduction phases, leading to significant hardware underutilization. To address this challenge, we formulate the backward pass of deterministic attention as a scheduling problem on a Directed Acyclic Graph (DAG) and derive schedules that minimize the critical path length. Building on this formulation, we present DASH (Deterministic Attention Scheduling for High-Throughput), which encapsulates two complementary scheduling strategies: (i) Descending Q-Tile Iteration, a reversed query-block traversal that shrinks pipeline stalls in causal attention, and (ii) Shift Scheduling, a theoretically optimal schedule within our DAG model that reduces pipeline stalls under both full and causal masks. Our empirical evaluations on NVIDIA H800 GPUs demonstrate that DASH narrows the performance gap between deterministic and non-deterministic attention: the proposed strategies improve the throughput of the attention backward pass by up to 1.28× over the baseline, significantly advancing the efficiency of reproducible LLM training. Our code is open-sourced at https://github.com/SJTU-Liquid/deterministic-FA3.
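To make the Descending Q-Tile Iteration idea concrete, the sketch below is a minimal illustration, not the paper's implementation: the function name and tiling details are assumptions. Under a causal mask, the worker assigned to K/V tile j only touches Q tiles i ≥ j, and the descending strategy simply reverses each worker's traversal over those tiles.

```python
# Illustrative sketch (hypothetical helper, not the DASH kernel): enumerate
# the Q tiles one K/V-tile worker visits in the causal backward pass.
def q_tile_order(num_q_tiles: int, kv_tile: int, descending: bool = False):
    """Q tiles visited by the worker for one K/V tile under a causal mask."""
    tiles = list(range(kv_tile, num_q_tiles))  # causal: only i >= kv_tile
    return tiles[::-1] if descending else tiles

# With 4 Q tiles, the worker for K/V tile 1 visits:
print(q_tile_order(4, 1))                   # ascending:  [1, 2, 3]
print(q_tile_order(4, 1, descending=True))  # descending: [3, 2, 1]
```

Because deterministic gradient accumulation serializes the contributions that different K/V-tile workers make to a shared Q tile, changing each worker's traversal order changes when those serialized updates line up, which is what the reversed order exploits to shrink pipeline stalls.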