🤖 AI Summary
This work addresses the challenge of training instability in Transformers, which often leads to divergence and wasted computational resources without reliable early warning. The authors introduce Koopman spectral analysis into Transformer stability research for the first time, proposing a method that extracts dynamic modal features from inter-layer residual snapshots via a single forward pass at initialization. This enables the construction of a "near-unit spectral mass" metric that predicts divergence risk with high accuracy (AUROC = 0.995). Furthermore, they design a Koopman Spectral Shaping (KSS) mechanism that actively regulates the spectral distribution during training, thereby enhancing stability. The approach reduces the divergence rate from 66.7% to 12.5% under aggressive settings (no normalization layers, high learning rates) and permits learning rate increases of 50%–150%.
📝 Abstract
Training divergence in transformers wastes compute, yet practitioners discover instability only after expensive runs begin. Practitioners therefore need an estimate of a transformer's failure probability before training starts. Residual Koopman Spectral Profiling (RKSP) provides such an estimate. From a single forward pass at initialization, RKSP extracts Koopman spectral features by applying whitened dynamic mode decomposition to layer-wise residual snapshots. Our central diagnostic, the near-unit spectral mass, quantifies the fraction of modes concentrated near the unit circle, which captures instability risk. For predicting divergence across extensive configurations, this estimator achieves an AUROC of 0.995, outperforming the best gradient baseline. We further make this diagnostic actionable through Koopman Spectral Shaping (KSS), which reshapes spectra during training. We validate the pipeline end to end: RKSP predicts divergence at initialization, and when RKSP flags high risk, enabling KSS prevents divergence. In the challenging regime of high learning rates without normalization layers, KSS reduces the divergence rate from 66.7% to 12.5% and enables learning rates that are 50% to 150% higher. These findings generalize to WikiText-103 language modeling, vision transformers on CIFAR-10, and pretrained language models, including GPT-2 and LLaMA-2 up to 7B, as well as emerging architectures such as MoE, Mamba-style SSMs, and KAN.
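To make the diagnostic concrete, here is a minimal sketch of how one might compute a near-unit spectral mass from layer-wise residual snapshots via dynamic mode decomposition. The paper's exact whitening scheme, pooling of residual activations, and the threshold `eps` are not specified in the abstract, so all of those choices below are illustrative assumptions, not the authors' implementation.

```python
import numpy as np

def near_unit_spectral_mass(residuals, eps=0.05):
    """Fraction of DMD eigenvalues within `eps` of the unit circle.

    residuals: array of shape (L, d) -- one pooled residual vector per
               layer, treated as a trajectory of the inter-layer dynamics.
    eps: band half-width around |lambda| = 1 (assumed hyperparameter).
    """
    # Snapshot pairs: columns of X map to columns of Y under the dynamics.
    X, Y = residuals[:-1].T, residuals[1:].T          # shape (d, L-1)
    # Standard reduced (projected) DMD: whiten X via its SVD.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    r = int((s > 1e-10 * s[0]).sum())                 # numerical rank
    U, s, Vt = U[:, :r], s[:r], Vt[:r]
    # Reduced Koopman operator approximation: A_tilde = U^T Y V S^{-1}.
    A_tilde = U.T @ Y @ Vt.T @ np.diag(1.0 / s)
    eigvals = np.linalg.eigvals(A_tilde)
    # Mass of modes whose modulus sits in the near-unit band.
    return float(np.mean(np.abs(np.abs(eigvals) - 1.0) < eps))
```

On a purely rotational (norm-preserving) trajectory the eigenvalues lie on the unit circle and the mass is 1; on a strongly contracting trajectory it is 0, matching the intuition that marginally stable inter-layer dynamics signal divergence risk.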