🤖 AI Summary
Problem: Modern Transformers no longer use ReLU, so enforcing sparsity in both feed-forward network (FFN) activations and attention typically forces a trade-off among model quality, parameter count, and training simplicity.
Method: We propose a hardware-friendly design for high sparsity with two components: (i) statistical top-k, a linear-time, sorting-free approximation of top-k selection that avoids the significant training slowdown of standard top-k operators; and (ii) a parameter reallocation mechanism that repurposes existing FFN weights and attention key embeddings as a lightweight predictor of which entries are activated. Neither component changes the standard pretraining procedure or adds parameters.
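A minimal sketch of the statistical top-k idea in item (i), assuming the entries of an activation vector are roughly Gaussian: instead of sorting, estimate the k-th largest value from the vector's mean and standard deviation through the Gaussian quantile function, then threshold in a single linear pass. The Gaussian assumption, the soft-thresholding variant (survivors are shifted down by the threshold), and the function names `statistical_topk_threshold`, `statistical_topk`, and `exact_topk_mask` are illustrative choices, not details given in the summary; the sorting-based mask is included only as a reference point.

```python
from __future__ import annotations

import random
from statistics import NormalDist, fmean, pstdev


def statistical_topk_threshold(x: list[float], k: int) -> float:
    """Estimate the k-th largest entry of x in O(d) time, assuming the entries
    are roughly Gaussian: theta = mean + std * Phi^{-1}(1 - k/d)."""
    d = len(x)
    if not 0 < k < d:
        raise ValueError("k must satisfy 0 < k < len(x)")
    mu = fmean(x)
    sigma = pstdev(x, mu)
    return mu + sigma * NormalDist().inv_cdf(1.0 - k / d)


def statistical_topk(x: list[float], k: int) -> list[float]:
    """Soft-threshold x at the estimated threshold (an assumed variant):
    entries below theta become 0, the rest are shifted down by theta.
    No sorting is involved."""
    theta = statistical_topk_threshold(x, k)
    return [max(v - theta, 0.0) for v in x]


def exact_topk_mask(x: list[float], k: int) -> list[float]:
    """Reference top-k masking via a full sort, O(d log d): keep the k largest."""
    keep = set(sorted(range(len(x)), key=lambda i: x[i], reverse=True)[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(x)]


# Usage: 8% density on a synthetic Gaussian activation vector.
random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(4096)]
k = int(0.08 * len(x))  # 327 entries
approx_nnz = sum(v > 0.0 for v in statistical_topk(x, k))
exact_nnz = sum(v != 0.0 for v in exact_topk_mask(x, k))
print(k, approx_nnz, exact_nnz)  # approx_nnz is near k but generally not equal to it
```

Because the threshold is estimated rather than exact, the number of surviving entries only approximates k; in exchange, the cost is a fixed number of linear passes with no sort.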
Results: Only 8% of FFN neurons are activated, and each token attends to at most 256 tokens. This reduces FLOPs by 2.5× and speeds up decoding by up to 1.79× on CPU and 1.40× on GPU, while remaining competitive on standard benchmarks. To our knowledge, this is the first work to achieve concurrent high sparsity in both the FFN and attention pathways of a modern architecture such as Gemma-2.
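The 256-token cap can be illustrated with a single-query attention sketch in which only the k highest-scoring keys enter the softmax. This is a hypothetical simplification: it scores every key at full dimension and selects keys exactly with `heapq.nlargest`, whereas the summary describes a low-cost predictor built from key embeddings and a sorting-free approximate top-k, neither of which is modeled here. `topk_attention` is an illustrative name.

```python
from __future__ import annotations

import heapq
import math
import random


def topk_attention(q: list[float], keys: list[list[float]],
                   values: list[list[float]], k: int = 256) -> list[float]:
    """Single-query attention in which the query attends to at most k keys.

    Only the k highest-scoring keys enter the softmax and the weighted sum
    of values; every other key receives zero weight."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, key)) for key in keys]
    # heapq.nlargest picks the k best indices in O(n log k), without a full sort.
    kept = heapq.nlargest(min(k, len(keys)), range(len(keys)), key=scores.__getitem__)
    m = max(scores[i] for i in kept)  # subtract the max for numerical stability
    weights = {i: math.exp(scores[i] - m) for i in kept}
    z = sum(weights.values())
    d_v = len(values[0])
    return [sum(w * values[i][j] for i, w in weights.items()) / z for j in range(d_v)]


# Usage: a 1,000-token context in which the query may use only 256 keys.
random.seed(0)
d, n = 8, 1000
keys = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
values = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
q = [random.gauss(0.0, 1.0) for _ in range(d)]
out = topk_attention(q, keys, values, k=256)
print(len(out))  # 8
```

When k is at least the context length, the function reduces to ordinary softmax attention; the cap only changes the result once the context grows beyond k tokens.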
📝 Abstract
The discovery of the lazy neuron phenomenon in trained Transformers, where the vast majority of neurons in their feed-forward networks (FFN) are inactive for each token, has spurred tremendous interest in activation sparsity for improving the efficiency of large models. While notable progress has been made in translating such sparsity into wall-time benefits, modern Transformers have moved away from the ReLU activation function that is crucial to this phenomenon. Existing efforts to re-introduce activation sparsity often degrade model quality, increase parameter count, or complicate or slow down training. Sparse attention, the application of sparse activation to the attention mechanism, often faces similar challenges. This paper introduces the Spark Transformer, a novel architecture that achieves a high level of activation sparsity in both the FFN and the attention mechanism while maintaining model quality, parameter count, and standard training procedures. Our method realizes sparsity via top-k masking, which gives explicit control over the sparsity level. Crucially, we introduce statistical top-k, a hardware-accelerator-friendly, linear-time approximate algorithm that avoids costly sorting and mitigates the significant training slowdown caused by standard top-k operators. Furthermore, Spark Transformer reallocates existing FFN parameters and attention key embeddings to form a low-cost predictor for identifying activated entries. This design not only mitigates the quality loss from enforced sparsity but also increases the wall-time benefit. Pretrained with the Gemma-2 recipe, Spark Transformer demonstrates competitive performance on standard benchmarks while exhibiting significant sparsity: only 8% of FFN neurons are activated, and each token attends to at most 256 tokens. This sparsity translates to a 2.5x reduction in FLOPs, leading to decoding wall-time speedups of up to 1.79x on CPU and 1.40x on GPU.
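The abstract states that Spark Transformer reallocates existing FFN parameters into a low-cost predictor of activated neurons but does not spell out the construction. The sketch below shows one hypothetical way such a reallocation could work: the first `r` input dimensions and the matching rows of the input weight matrix score every neuron cheaply, the top k neurons are kept, and the remaining rows are evaluated only for those k neurons, so the predictor adds no parameters. The split point `r`, the GELU gate, exact selection via `heapq.nlargest`, and the names `predictor_ffn` and `ffn_macs` are all our assumptions, not the paper's definitions.

```python
from __future__ import annotations

import heapq
import math


def gelu(z: float) -> float:
    """Exact GELU, used here as an assumed gate nonlinearity."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


def predictor_ffn(x: list[float], w_in: list[list[float]],
                  w_out: list[list[float]], r: int, k: int) -> list[float]:
    """Hypothetical sparse FFN whose activation predictor reuses parameters.

    x: input of size d_model; w_in: d_model x d_ff; w_out: d_ff x d_model.
    The first r rows of w_in score every neuron; the remaining d_model - r
    rows are evaluated only for the k selected neurons."""
    d_model, d_ff = len(x), len(w_in[0])
    # 1) Low-cost predictor: r * d_ff multiply-adds instead of d_model * d_ff.
    scores = [sum(x[i] * w_in[i][j] for i in range(r)) for j in range(d_ff)]
    # 2) Keep the k highest-scoring neurons (an approximate top-k could go here).
    active = heapq.nlargest(k, range(d_ff), key=scores.__getitem__)
    out = [0.0] * d_model
    for j in active:
        # 3) Value term from the remaining input dims, gated by the score.
        value = sum(x[i] * w_in[i][j] for i in range(r, d_model))
        h = gelu(scores[j]) * value
        for m in range(d_model):
            out[m] += h * w_out[j][m]
    return out


def ffn_macs(d_model: int, d_ff: int, r: int, k: int) -> tuple[int, int]:
    """Multiply-adds for the dense version of this layer vs. the sparse one."""
    dense = 2 * d_model * d_ff
    sparse = r * d_ff + k * (d_model - r) + k * d_model
    return dense, sparse


# Usage: toy sizes with 8% of 4096 neurons (327) active.
print(ffn_macs(d_model=1024, d_ff=4096, r=128, k=327))  # (8388608, 1152128)
```

The multiply-add count covers only this one toy layer; it is not meant to reproduce the 2.5x whole-model FLOP reduction reported above, which also reflects attention and the rest of the network.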