🤖 AI Summary
Transformers rely heavily on adaptive optimizers such as AdamW because their heavy-tailed gradient distributions destabilize training with momentum SGD (mSGDW).
Method: The paper proposes the Deeply Normalized Transformer (DNT), a systematic design that embeds normalization modules at multiple locations across Transformer layers to jointly regularize weights, activations, and layer Jacobians, thereby substantially compressing the tails of the gradient distribution.
Contribution/Results: DNT enables stable and efficient Transformer training with vanilla mSGDW alone, removing the need for adaptive optimization. Theoretical analysis and extensive experiments show that DNT-based ViT and GPT-style models trained with mSGDW match or surpass their AdamW-trained counterparts in both convergence speed and generalization on image classification and language modeling, offering a principled alternative to adaptive optimizers.
📝 Abstract
Transformers have become the de facto backbone of modern deep learning, yet their training typically demands an optimizer with adaptive learning rates, such as AdamW, rather than momentum SGDW (mSGDW). Previous works show that this is mainly due to the heavy-tailed distribution of the gradients. In this paper, we introduce the Deeply Normalized Transformer (DNT), which is meticulously engineered to overcome this limitation, enabling seamless training with vanilla mSGDW while yielding performance comparable to Transformers trained with AdamW. Specifically, in DNT we strategically integrate normalization techniques at proper positions in the Transformer to effectively modulate the Jacobian matrices of each layer, balance the influence of weights, activations, and their interactions, and thereby concentrate the distribution of the gradients. We provide both theoretical justification for the normalization techniques used in DNT and extensive empirical evaluation on two popular Transformer architectures, validating that: (a) DNT outperforms its counterparts (i.e., ViT and GPT), and (b) DNT can be effectively trained with vanilla mSGDW.
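To make the "normalize weights, activations, and their interactions" idea concrete, here is a minimal NumPy sketch of a feed-forward sub-block with normalization applied at several positions: to the incoming activations, to each weight matrix, and to the output before the residual addition. The exact placements in DNT are not specified in the abstract, so the positions, function names (`dnt_style_ffn_block`, `weight_norm`), and dimensions below are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize activations along the feature axis (zero mean, unit variance).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def weight_norm(W, eps=1e-5):
    # Normalize each output row of the weight matrix to unit L2 norm,
    # bounding the scale this layer contributes to the Jacobian.
    return W / (np.linalg.norm(W, axis=1, keepdims=True) + eps)

def dnt_style_ffn_block(x, W1, W2):
    # Hypothetical "deeply normalized" feed-forward sub-block:
    # normalization on input activations, both weight matrices,
    # and the output before the residual connection.
    h = layer_norm(x)                           # pre-normalize activations
    h = np.maximum(h @ weight_norm(W1).T, 0.0)  # normalized weights + ReLU
    h = h @ weight_norm(W2).T
    return x + layer_norm(h)                    # post-normalize before residual

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))    # (tokens, d_model) -- toy sizes
W1 = rng.normal(size=(64, 16))  # (d_ff, d_model)
W2 = rng.normal(size=(16, 64))  # (d_model, d_ff)
y = dnt_style_ffn_block(x, W1, W2)
print(y.shape)  # (4, 16)
```

Because every weight row and every normalized activation has bounded norm, the per-layer contribution to the backward pass is kept at a controlled scale, which is the mechanism the abstract credits for concentrating the gradient distribution.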