🤖 AI Summary
This work investigates the optimization dynamics and generalization behavior of Sign Gradient Descent (SignGD) for two-layer Transformers, consisting of a softmax attention layer with trainable query-key parameterization followed by a linear layer, trained on noisy linearly separable data. Through a theoretical analysis of the training dynamics, we characterize, for the first time, a four-stage training trajectory under SignGD and prove its fast convergence, while establishing a significantly worse generalization error bound compared to standard gradient descent. Crucially, we show that the generalization deficit is not caused by data noise alone: the sign-based parameter update itself plays a central role. Notably, Adam exhibits nearly identical optimization trajectories and generalization failure under the same setting. Experiments on both synthetic and real-world datasets support our theoretical predictions. Our findings offer new insights into optimizer-induced biases and suggest that SignGD and Adam require high-quality data for real-world tasks.
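To make the setting concrete, here is a minimal NumPy sketch of a two-layer transformer (softmax attention with trainable query-key matrices `Wq`, `Wk`, followed by a linear layer `w`) trained with SignGD. It is an illustration under stated assumptions, not the authors' construction: the data model (each sample is a signal token `y * mu` plus a Gaussian noise token), the value matrix fixed to the identity, averaging the outputs over tokens, the logistic loss, and all hyperparameters are hypothetical choices, as are the helper names `make_data`, `loss_and_grads`, and `train_signgd`. Gradients are derived by hand, and SignGD moves every parameter by `-lr * sign(grad)`.

```python
import numpy as np


def make_data(n, d, sigma, rng):
    """Hypothetical noisy, linearly separable data: each sample has two tokens,
    a signal token y * mu and a Gaussian noise token (an assumed data model)."""
    mu = np.zeros(d)
    mu[0] = 2.0
    y = rng.choice([-1.0, 1.0], size=n)
    noise = sigma * rng.standard_normal((n, d))
    X = np.stack([y[:, None] * mu, noise], axis=1)  # shape (n, 2, d)
    return X, y


def softmax(S):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)


def loss_and_grads(params, X, y):
    """Mean logistic loss and hand-derived gradients for (Wq, Wk, w)."""
    Wq, Wk, w = params
    gq, gk, gw = np.zeros_like(Wq), np.zeros_like(Wk), np.zeros_like(w)
    total = 0.0
    for Xi, yi in zip(X, y):
        L = Xi.shape[0]
        Q, K = Xi @ Wq, Xi @ Wk            # queries and keys, shape (L, m)
        A = softmax(Q @ K.T)               # attention weights, shape (L, L)
        H = A @ Xi                         # value matrix fixed to identity
        f = (H @ w).mean()                 # linear layer, averaged over tokens
        margin = yi * f
        total += np.logaddexp(0.0, -margin)  # log(1 + exp(-y f))
        # Backpropagation; d loss / d f = -y * sigmoid(-y f).
        df = -yi * 0.5 * (1.0 - np.tanh(0.5 * margin))
        dout = np.full(L, df / L)
        gw += H.T @ dout
        dA = np.outer(dout, w) @ Xi.T      # gradient w.r.t. attention A
        dS = A * (dA - (dA * A).sum(axis=1, keepdims=True))  # softmax backward
        gq += Xi.T @ (dS @ K)
        gk += Xi.T @ (dS.T @ Q)
    n = len(y)
    return total / n, [gq / n, gk / n, gw / n]


def train_signgd(X, y, m=4, lr=0.01, steps=50, seed=0):
    """Train (Wq, Wk, w) with SignGD: p <- p - lr * sign(grad)."""
    rng = np.random.default_rng(seed)
    d = X.shape[2]
    params = [0.01 * rng.standard_normal((d, m)),
              0.01 * rng.standard_normal((d, m)),
              np.zeros(d)]                 # zero linear layer: initial loss is log 2
    history = []
    for _ in range(steps):
        loss, grads = loss_and_grads(params, X, y)
        history.append(loss)
        params = [p - lr * np.sign(g) for p, g in zip(params, grads)]
    return params, history


X, y = make_data(n=16, d=20, sigma=0.5, rng=np.random.default_rng(1))
params, history = train_signgd(X, y)
print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
```

Because the linear layer starts at zero, the first recorded loss is exactly log 2, and the query-key matrices only start moving once `w` is nonzero. The `history` list can be plotted to inspect the training curve; whether these toy settings reproduce the four stages described above is not guaranteed.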
📝 Abstract
The Adam optimizer is widely used for transformer optimization in practice, which makes understanding the underlying optimization mechanisms an important problem. However, due to Adam's complexity, theoretical analysis of how it optimizes transformers remains a challenging task. Fortunately, Sign Gradient Descent (SignGD) serves as an effective surrogate for Adam. Despite its simplicity, theoretical understanding of how SignGD optimizes transformers still lags behind. In this work, we study how SignGD optimizes a two-layer transformer, consisting of a softmax attention layer with trainable query-key parameterization followed by a linear layer, on a linearly separable noisy dataset. We identify four stages in the training dynamics, each exhibiting intriguing behaviors. Based on the training dynamics, we prove the fast convergence but poor generalization of the learned transformer on the noisy dataset. We also show that Adam behaves similarly to SignGD in terms of both optimization and generalization in this setting. Additionally, we find that the poor generalization of SignGD is not solely due to data noise, suggesting that both SignGD and Adam require high-quality data for real-world tasks. Finally, experiments on synthetic and real-world datasets empirically support our theoretical results.
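One common way to see why SignGD can stand in for Adam: with β1 = β2 = 0 and a vanishing ε, the bias-corrected first moment is the gradient g and the second moment is g², so the Adam step becomes lr · g / |g| = lr · sign(g) for any nonzero g. The pure-Python sketch below illustrates this limiting case on a flat list of parameters; the function names `signgd_step` and `adam_step` are hypothetical, and the Adam update is the standard one with bias correction rather than any particular library's implementation.

```python
import math


def sign(x):
    """Return -1, 0 or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def signgd_step(params, grads, lr):
    """One SignGD step: move each parameter by lr against the gradient sign."""
    return [p - lr * sign(g) for p, g in zip(params, grads)]


def adam_step(params, grads, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step with bias correction; t is the 1-based step count."""
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g        # first-moment estimate
        v_i = beta2 * v_i + (1 - beta2) * g * g    # second-moment estimate
        m_hat = m_i / (1 - beta1 ** t)
        v_hat = v_i / (1 - beta2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


params = [1.0, -2.0, 0.5]
grads = [0.3, -4.0, 1e-3]  # nonzero gradients, so sign(g) = g / |g|
sign_params = signgd_step(params, grads, lr=0.1)
adam_params, _, _ = adam_step(params, grads, [0.0] * 3, [0.0] * 3, t=1,
                              lr=0.1, beta1=0.0, beta2=0.0, eps=1e-12)
print(sign_params)
print(adam_params)
```

With bias correction, the very first Adam step equals the sign step for any β1, β2 < 1; the two updates diverge from the second step on, once the moving averages mix gradients from different iterations. The surrogate is therefore a modeling choice, and the claim that Adam behaves similarly refers to the specific setting studied here.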