First Attentions Last: Better Exploiting First Attentions for Efficient Transformer Training

📅 2025-10-16
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
In large-scale Transformer training with tensor parallelism (TP), frequent all-reduce communications between multi-head attention (MHA) and MLP modules per layer impose a severe efficiency bottleneck. Method: We propose FAL, the first architecture that replaces conventional inter-module activation signals with the output of the initial attention layer, enabling a restructured data flow that eliminates all-reduce operations between MHA and MLP within each layer and allows their full parallel execution. We further introduce FAL+, which incorporates normalized attention enhancement and output redirection to improve representational capacity without incurring additional communication overhead. Results: Experiments show FAL achieves up to 44% speedup in multi-GPU training and 1.18× higher single-GPU throughput versus baseline GPT, while attaining lower perplexity. FAL+ further reduces perplexity, demonstrating that communication elimination and model quality improvement are jointly attainable.

๐Ÿ“ Abstract
As training billion-scale transformers becomes increasingly common, employing multiple distributed GPUs along with parallel training methods has become standard practice. However, existing transformer designs suffer from significant communication overhead, especially in Tensor Parallelism (TP), where each block's MHA-MLP connection requires an all-reduce communication. Through our investigation, we show that the MHA-MLP connections can be bypassed for efficiency, while the attention output of the first layer can serve as an alternative signal for the bypassed connection. Motivated by these observations, we propose FAL (First Attentions Last), an efficient transformer architecture that redirects the first MHA output to the MLP inputs of the following layers, eliminating the per-block MHA-MLP connections. This removes the all-reduce communication and enables parallel execution of MHA and MLP on a single GPU. We also introduce FAL+, which adds the normalized first attention output to the MHA outputs of the following layers to augment the MLP input for better model quality. Our evaluation shows that FAL reduces multi-GPU training time by up to 44%, improves single-GPU throughput by up to 1.18x, and achieves better perplexity compared to the baseline GPT. FAL+ achieves even lower perplexity without increasing training time relative to the baseline.
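The restructured data flow described above can be sketched as a toy example. This is a minimal NumPy sketch with stand-in attention and MLP stubs; the exact residual wiring and normalization placement are assumptions, not the authors' implementation.

```python
# Toy sketch of the FAL idea: in a standard block, the MLP consumes the
# block's own MHA output, serializing the two modules (and, under tensor
# parallelism, placing an all-reduce between them). FAL instead feeds the
# cached first-layer attention output to every later block's MLP, so the
# two branches no longer depend on each other within a block.
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy model width

def layer_norm(x):
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)

def mha_stub(x, w):        # stand-in for multi-head attention
    return layer_norm(x) @ w

def mlp_stub(x, w1, w2):   # stand-in for the feed-forward module
    return np.maximum(layer_norm(x) @ w1, 0.0) @ w2

def baseline_block(x, w_attn, w1, w2):
    # MLP input depends on this block's MHA output -> serialized execution.
    h = x + mha_stub(x, w_attn)
    return h + mlp_stub(h, w1, w2)

def fal_block(x, first_attn, w_attn, w1, w2):
    # Both branches read only x and the cached first-layer attention
    # output, so they can run in parallel with no all-reduce between them
    # (the additive combination here is an assumed simplification).
    attn_out = mha_stub(x, w_attn)
    mlp_out = mlp_stub(x + first_attn, w1, w2)
    return x + attn_out + mlp_out

x = rng.normal(size=(4, d))                      # (tokens, width)
w_attn, w1, w2 = (rng.normal(size=(d, d)) for _ in range(3))

first_attn = mha_stub(x, w_attn)                 # layer 1 computes and caches this
y = fal_block(x, first_attn, w_attn, w1, w2)
print(y.shape)  # (4, 8)
```

The key property the sketch illustrates is that `fal_block`'s two branches share no data dependency, which is what removes the per-block all-reduce under tensor parallelism.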
Problem

Research questions and friction points this paper is trying to address.

Significant all-reduce communication overhead between MHA and MLP in tensor-parallel transformer training
Per-block MHA-MLP connections force serialized execution of the two modules
Whether MHA and MLP can run in parallel without degrading model quality
Innovation

Methods, ideas, or system contributions that make the work stand out.

Redirects first MHA output to subsequent MLP inputs
Eliminates per-block MHA-MLP connections and all-reduce communication
Enables parallel execution of MHA and MLP on single GPU
Gyudong Kim
Korea University
Hyukju Na
Korea University
Jin Hyeon Kim
Korea University
Hyunsung Jang
Yonsei University
Computer Vision · Deep Learning · Machine Learning
Jaemin Park
LIG Nex1 Co., Ltd.
Jaegi Hwang
LIG Nex1 Co., Ltd.
Namkoo Ha
LIG Nex1 Co., Ltd.
Seungryong Kim
Associate Professor, KAIST
Computer Vision · Machine Learning
Young Geun Kim
Korea University
Operating Systems · Computer Architecture · Embedded Systems · Energy/Power Management · Mobile/IoT Architecture