Flash Multi-Head Feed-Forward Network

📅 2025-12-07
📈 Citations: 0
Influential: 0
🤖 AI Summary
To address two key challenges in Transformer multi-head feed-forward networks (MH-FFNs)—linear memory growth with increasing head count and imbalanced intermediate-to-head dimension ratios that hinder scalability and expressiveness—this paper proposes FlashMHF, an I/O-aware efficient MH-FFN implementation. Its core innovations include a fused online computation kernel inspired by FlashAttention, dynamically weighted parallel subnetworks, and SRAM-optimized scheduling, jointly integrated with SwiGLU activation to preserve dimensional balance. Evaluated on models ranging from 128M to 1.3B parameters, FlashMHF achieves significantly lower perplexity and higher downstream task accuracy compared to standard SwiGLU FFNs, reduces peak memory usage by 3–5×, and accelerates inference by up to 1.08×. To our knowledge, FlashMHF is the first MH-FFN architecture achieving high scalability, low computational overhead, and strong representational capacity simultaneously.

📝 Abstract
We explore Multi-Head FFN (MH-FFN) as a replacement of FFN in the Transformer architecture, motivated by the structural similarity between single-head attention and FFN. While multi-head mechanisms enhance expressivity in attention, naively applying them to FFNs faces two challenges: memory consumption scaling with the head count, and an imbalanced ratio between the growing intermediate size and the fixed head dimension as models scale, which degrades scalability and expressive power. To address these challenges, we propose Flash Multi-Head FFN (FlashMHF), with two key innovations: an I/O-aware fused kernel computing outputs online in SRAM akin to FlashAttention, and a design using dynamically weighted parallel sub-networks to maintain a balanced ratio between intermediate and head dimensions. Validated on models from 128M to 1.3B parameters, FlashMHF consistently improves perplexity and downstream task accuracy over SwiGLU FFNs, while reducing peak memory usage by 3-5x and accelerating inference by up to 1.08x. Our work establishes the multi-head design as a superior architectural principle for FFNs, presenting FlashMHF as a powerful, efficient, and scalable alternative to FFNs in Transformers.
Problem

Research questions and friction points this paper is trying to address.

Replacing the Transformer FFN with a multi-head design raises two frictions: activation memory grows linearly with head count, and the intermediate-to-head dimension ratio becomes imbalanced as models scale
FlashMHF answers with an I/O-aware fused kernel and dynamically weighted sub-networks that keep the intermediate and head dimensions balanced
It improves perplexity and downstream accuracy while cutting peak memory 3–5× and speeding inference by up to 1.08×
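The memory friction above can be made concrete with a minimal sketch (not the paper's implementation; sizes and the ReLU activation are hypothetical) of a naive multi-head FFN: the intermediate activation is materialized for all heads at once, so its size grows linearly with the head count.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, d_head, d_ff = 64, 8, 8, 32  # d_model = n_heads * d_head

def naive_mh_ffn(x, W1, W2):
    """x: (tokens, d_model). Each head runs its own small FFN slice."""
    t = x.shape[0]
    # Split the model dimension across heads: (tokens, heads, d_head).
    xh = x.reshape(t, n_heads, d_head)
    # Intermediate activation materialized for ALL heads at once:
    # shape (tokens, heads, d_ff) -- memory scales linearly with n_heads.
    h = np.maximum(np.einsum("thd,hdf->thf", xh, W1), 0.0)  # ReLU for brevity
    out = np.einsum("thf,hfd->thd", h, W2)
    return out.reshape(t, d_model), h.size

x = rng.standard_normal((16, d_model))
W1 = rng.standard_normal((n_heads, d_head, d_ff)) * 0.1
W2 = rng.standard_normal((n_heads, d_ff, d_head)) * 0.1
y, act_elems = naive_mh_ffn(x, W1, W2)
print(act_elems)  # 16 tokens * 8 heads * 32 = 4096 elements, linear in n_heads
```

Doubling `n_heads` doubles `act_elems`, which is exactly the scaling FlashMHF's fused kernel is designed to avoid.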
Innovation

Methods, ideas, or system contributions that make the work stand out.

FlashMHF uses an I/O-aware fused kernel that computes head outputs online in SRAM, akin to FlashAttention
It employs dynamically weighted parallel sub-networks to maintain a balanced ratio between the intermediate and head dimensions
Together these cut peak memory usage 3–5× and accelerate inference by up to 1.08× in Transformers
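The two ideas above can be sketched at block level in plain numpy. This is an assumption-laden illustration, not the paper's kernel: heads are streamed one at a time so only a single head's SwiGLU intermediate is live at once (the CPU-side analogue of FlashAttention-style online computation in SRAM), and "dynamically weighted" is realized here as a per-token softmax over heads.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, d_ff = 64, 8, 32

def swiglu(x, Wg, Wu):
    g = x @ Wg
    return (g / (1.0 + np.exp(-g))) * (x @ Wu)  # SiLU(gate) * up-projection

def flash_like_mh_ffn(x, Wg, Wu, Wd, Wmix):
    # Dynamic head weights: softmax over heads, per token (our assumption
    # about how the dynamic weighting is realized).
    logits = x @ Wmix                              # (tokens, n_heads)
    w = np.exp(logits - logits.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    out = np.zeros_like(x)
    for h in range(n_heads):                       # stream heads one at a time
        hid = swiglu(x, Wg[h], Wu[h])              # (tokens, d_ff): ONE head live
        out += w[:, h:h+1] * (hid @ Wd[h])         # weighted online accumulation
    return out

x = rng.standard_normal((4, d_model))
Wg = rng.standard_normal((n_heads, d_model, d_ff)) * 0.1
Wu = rng.standard_normal((n_heads, d_model, d_ff)) * 0.1
Wd = rng.standard_normal((n_heads, d_ff, d_model)) * 0.1
Wmix = rng.standard_normal((d_model, n_heads)) * 0.1
y = flash_like_mh_ffn(x, Wg, Wu, Wd, Wmix)
print(y.shape)  # (4, 64)
```

Peak live intermediate memory here is O(tokens × d_ff) regardless of head count, whereas the naive formulation holds O(tokens × n_heads × d_ff); the actual FlashMHF kernel fuses this loop on-chip rather than in Python.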