On the Optimization and Generalization of Two-layer Transformers with Sign Gradient Descent

📅 2024-10-07
🏛️ arXiv.org
📈 Citations: 3
Influential: 0
🤖 AI Summary
This work investigates the optimization dynamics and generalization behavior of Sign Gradient Descent (SignGD) for two-layer Transformers—comprising trainable softmax attention and a linear output layer—on noisy linearly separable data. Using rigorous dynamical systems analysis, we characterize, for the first time, a four-phase training trajectory under SignGD and prove its fast convergence, while establishing a significantly worse generalization error bound compared to standard gradient descent. Crucially, we identify the sign-based parameter update mechanism—not merely label noise—as the fundamental cause of its generalization deficit. Notably, Adam exhibits nearly identical optimization trajectories and generalization failure under the same setting. Extensive experiments on both synthetic and real-world datasets validate our theoretical predictions. Our findings offer new insights into optimizer-induced biases and the generalization gap in large language models.

📝 Abstract
The Adam optimizer is widely used for transformer optimization in practice, which makes understanding its underlying optimization mechanisms an important problem. However, due to Adam's complexity, theoretical analysis of how it optimizes transformers remains a challenging task. Fortunately, Sign Gradient Descent (SignGD) serves as an effective surrogate for Adam. Despite its simplicity, theoretical understanding of how SignGD optimizes transformers still lags behind. In this work, we study how SignGD optimizes a two-layer transformer -- consisting of a softmax attention layer with trainable query-key parameterization followed by a linear layer -- on a linearly separable noisy dataset. We identify four stages in the training dynamics, each exhibiting intriguing behaviors. Based on the training dynamics, we prove the fast convergence but poor generalization of the learned transformer on the noisy dataset. We also show that Adam behaves similarly to SignGD in terms of both optimization and generalization in this setting. Additionally, we find that the poor generalization of SignGD is not solely due to data noise, suggesting that both SignGD and Adam require high-quality data for real-world tasks. Finally, experiments on synthetic and real-world datasets empirically support our theoretical results.
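The abstract's claim that SignGD serves as a surrogate for Adam can be seen directly from the update rules: Adam's per-coordinate normalization by the second-moment estimate makes its step approximately sign-like. A minimal NumPy sketch (illustrative only, not the paper's code; function names and hyperparameters are assumptions):

```python
import numpy as np

def signgd_step(params, grads, lr=1e-3):
    """One SignGD update: each coordinate moves a fixed distance lr
    in the direction opposite the sign of its gradient."""
    return {k: p - lr * np.sign(grads[k]) for k, p in params.items()}

def adam_step(params, grads, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update (with bias correction). With b1 = b2 = 0 and
    eps -> 0, the step reduces to lr * sign(grad), i.e. SignGD --
    the sense in which SignGD approximates Adam."""
    new_params = {}
    for k, p in params.items():
        m[k] = b1 * m[k] + (1 - b1) * grads[k]          # first moment
        v[k] = b2 * v[k] + (1 - b2) * grads[k] ** 2     # second moment
        m_hat = m[k] / (1 - b1 ** t)                    # bias correction
        v_hat = v[k] / (1 - b2 ** t)
        new_params[k] = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return new_params
```

Setting `b1 = b2 = 0` in `adam_step` yields an update of `lr * grad / (|grad| + eps)`, which matches `signgd_step` up to `eps`; this is the standard motivation for analyzing SignGD in place of Adam.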
Problem

Research questions and friction points this paper is trying to address.

How does SignGD optimize a two-layer transformer with trainable softmax attention on noisy, linearly separable data?
Why does SignGD converge quickly yet generalize poorly on noisy data?
Does Adam share SignGD's optimization dynamics and generalization behavior in this setting?
Innovation

Methods, ideas, or system contributions that make the work stand out.

First theoretical characterization of SignGD's training dynamics for two-layer transformers, revealing four distinct stages.
Proof of fast convergence alongside poor generalization on noisy data, attributed to the sign-based update rather than data noise alone.
Evidence that Adam mirrors SignGD in this setting, implying both optimizers require high-quality data in practice.