🤖 AI Summary
This work addresses the problem of detecting grokking (the delayed, abrupt emergence of generalization after the training data has already been fit) before test accuracy rises substantially. The authors propose a window-level localization method based on spectral diagnostics of empirical distributions. Task-dependent observables are summarized as empirical distributions, mapped to Wasserstein or quantile coordinates, stacked into Hankel matrices, and analyzed with Dynamic Mode Decomposition (DMD). The DMD reconstruction residual, spectrum, and effective rank form the diagnostic output. In perturbation experiments, high-residual windows show about 3× larger short-horizon perturbation deviation than low-residual windows. In a same-data norm control, perturbation sensitivity follows the residual ordering rather than the total-parameter-norm ordering, which suggests that the residual is not merely a norm proxy at the window level. On held-out modular-addition Transformer runs, the residual achieves an AUROC of about 0.93 for distinguishing grokking from non-grokking runs at the run level. Under a fixed sustained-threshold rule, true-positive alarms can precede onset, and lead time is reported jointly with the false-alarm rate. The authors position the residual as a window-level monitoring signal for the studied modular-arithmetic settings rather than a universal early-warning predictor.
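The pipeline summarized above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation. Each observable is assumed to be a set of scalar samples per training step, and the embedding is the quantile function on a fixed grid of levels. For 1-D distributions, the 2-Wasserstein distance equals the L2 distance between quantile functions, so this grid approximates Wasserstein coordinates. The residual is taken to be the relative one-step fit error of exact DMD with a rank-truncated SVD. Effective rank is taken to be the exponential of the Shannon entropy of the normalized singular values. All function names, the delay count, the truncation rank, and the toy data are hypothetical.

```python
import numpy as np

def quantile_embedding(samples, levels):
    """Quantile coordinates of an empirical 1-D distribution at fixed levels.
    For 1-D distributions, the Euclidean distance between such vectors (evenly
    spaced levels, scaled by 1/sqrt(len(levels))) approximates the W2 distance."""
    return np.quantile(np.asarray(samples, dtype=float), levels)

def hankel(series, delays):
    """Stack `delays` time-shifted copies of a (T, d) series into a
    (delays * d, T - delays + 1) Hankel matrix, one column per time step."""
    T, d = series.shape
    cols = T - delays + 1
    return np.vstack([series[k:k + cols].T for k in range(delays)])

def dmd_residual(H, rank, tol=1e-10):
    """Exact DMD on snapshot pairs (X, Y) = (H[:, :-1], H[:, 1:]).
    Returns the relative one-step reconstruction residual and the DMD eigenvalues."""
    X, Y = H[:, :-1], H[:, 1:]
    U, s, Vh = np.linalg.svd(X, full_matrices=False)
    r = max(1, min(rank, int(np.sum(s > tol * s[0]))))  # drop numerically zero modes
    U, s, Vh = U[:, :r], s[:r], Vh[:r]
    A_tilde = (U.T @ Y @ Vh.T) / s     # reduced operator U^T Y V S^{-1}
    Y_hat = U @ A_tilde @ (U.T @ X)    # one-step prediction of Y from X
    residual = np.linalg.norm(Y - Y_hat) / np.linalg.norm(Y)
    return float(residual), np.linalg.eigvals(A_tilde)

def effective_rank(H):
    """Entropy-based effective rank: exp of the Shannon entropy of normalized singular values."""
    s = np.linalg.svd(H, compute_uv=False)
    p = s[s > 0] / s.sum()
    return float(np.exp(-(p * np.log(p)).sum()))

def windowed_residuals(traj, window, delays, rank):
    """DMD residual for every sliding window of `window` consecutive steps."""
    return np.array([dmd_residual(hankel(traj[i:i + window], delays), rank)[0]
                     for i in range(len(traj) - window + 1)])

# Toy usage with invented data (not from the paper).
levels = np.linspace(0.05, 0.95, 8)
base = np.linspace(-1.0, 1.0, 64)
# A distribution that shifts by a constant each step: affine drift is
# exactly linear in delay coordinates, so the residual should be ~0.
drift = np.stack([quantile_embedding(base + 0.1 * t, levels) for t in range(30)])
res_drift, eig_drift = dmd_residual(hankel(drift, delays=4), rank=5)

# Independent resampling noise at every step: no linear model explains it.
rng = np.random.default_rng(0)
noisy = np.stack([quantile_embedding(rng.normal(size=64), levels) for _ in range(30)])
res_noisy, _ = dmd_residual(hankel(noisy, delays=4), rank=5)
print(res_drift, res_noisy, effective_rank(hankel(drift, delays=4)))
```

In the toy usage, the constant-shift trajectory gives a residual near zero and DMD eigenvalues near 1, while independent resampling noise leaves a nonzero residual. A high per-window residual therefore marks windows where a low-rank linear model fits the recent distributional trajectory poorly.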
📝 Abstract
In grokking, a model first fits the training data while test accuracy remains low, and only later begins to generalize. We ask whether this transition can be localized from observed training trajectories before the test accuracy rises, and formulate grokking transition localization as a diagnostic problem with an explicit threshold/FPR/lead-time trade-off. Task-dependent observables are summarized as empirical distributions, mapped to Wasserstein/quantile coordinates, and analyzed by Hankel dynamic mode decomposition (DMD); the resulting reconstruction residual, together with the DMD spectrum and effective rank, forms the diagnostic output. On held-out modular-addition Transformer runs, the residual achieves AUROC \(\approx 0.93\) for grokking-vs-non-grokking discrimination at the run level; under a fixed sustained-threshold operating rule, true-positive alarms can precede onset, with lead time reported jointly with the false-alarm rate and uncertainty intervals. Perturbation experiments show that, in the tested weight-decay \(wd=1\) pool, high-residual windows exhibit about \(3\times\) larger short-horizon perturbation deviation than low-residual windows. In a same-data norm-window control, perturbation sensitivity aligns with the residual ordering rather than the total-parameter-norm ordering, suggesting that the residual is not merely a total-norm proxy at the window level in the studied \(wd=1\) dynamics. Norm signals remain strong run-level regime indicators, and log-probability performs best among the observables tested under the current protocol. We position the residual as a window-level monitoring and localization signal in the studied modular-arithmetic Transformer settings, not a universal early-warning predictor or an intervention rule.
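The threshold/FPR/lead-time trade-off and the sustained-threshold rule can be made concrete with a small sketch. The abstract does not spell out the exact operating rule, so every detail below is an assumption chosen for illustration. An alarm fires at the first window that completes `k` consecutive residuals above `threshold`. Lead time is counted only for alarms strictly before the onset window. The false-alarm rate is the fraction of non-grokking runs that ever alarm. The run-level score used for AUROC is the maximum window residual, and AUROC is computed as the Mann-Whitney pairwise statistic. The helper names and the toy numbers are hypothetical.

```python
import numpy as np

def sustained_alarm(scores, threshold, k):
    """Index of the first window that completes a run of k consecutive
    scores above threshold, or None if no alarm fires."""
    run = 0
    for i, s in enumerate(scores):
        run = run + 1 if s > threshold else 0
        if run >= k:
            return i
    return None

def auroc(pos, neg):
    """Mann-Whitney estimate of P(pos score > neg score); ties count 1/2."""
    pos = np.asarray(pos, dtype=float)[:, None]
    neg = np.asarray(neg, dtype=float)[None, :]
    return float(((pos > neg) + 0.5 * (pos == neg)).mean())

def evaluate(grok_runs, nongrok_runs, onsets, threshold, k):
    """grok_runs / nongrok_runs: per-window residual arrays, one per run.
    onsets: index of the onset window for each grokking run."""
    lead_times = []
    for scores, onset in zip(grok_runs, onsets):
        t = sustained_alarm(scores, threshold, k)
        if t is not None and t < onset:  # true-positive alarm before onset
            lead_times.append(onset - t)
    fpr = float(np.mean([sustained_alarm(s, threshold, k) is not None
                         for s in nongrok_runs]))
    auc = auroc([np.max(s) for s in grok_runs], [np.max(s) for s in nongrok_runs])
    return {"early_detection_rate": len(lead_times) / len(grok_runs),
            "lead_times": lead_times, "fpr": fpr, "auroc": auc}

# Toy per-window residuals (invented numbers, purely illustrative).
grok = [np.array([0.1, 0.2, 0.6, 0.7, 0.8, 0.9]),
        np.array([0.1, 0.1, 0.1, 0.7, 0.8, 0.9])]
nongrok = [np.array([0.1, 0.2, 0.3, 0.2, 0.1, 0.2]),
           np.array([0.1, 0.6, 0.2, 0.6, 0.2, 0.1])]
report = evaluate(grok, nongrok, onsets=[5, 4], threshold=0.5, k=2)
print(report)
```

In this toy case the first grokking run alarms two windows before onset. The second alarms only at onset and so is not counted as early. The isolated spikes in the second non-grokking run never satisfy the two-window requirement, which is the point of a sustained rule. Sweeping `threshold` and `k` traces out the trade-off: lower thresholds or shorter runs give earlier alarms at the cost of a higher false-alarm rate.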