AI Summary
Existing average-sensitivity metrics struggle to characterize the robustness of large language models to input perturbations and fail to capture their "junta-like" input dependency. This work introduces noise stability into Transformer analysis as a novel measure of model simplicity and robustness. Through a theoretical analysis of noise stability in single-layer attention and ReLU MLP blocks, and by modeling multi-layer signal propagation via covariance interval propagation, we devise a new regularization strategy for training. The proposed method accelerates training by approximately 35% on algorithmic tasks and 75% on next-token prediction tasks, while consistently promoting the emergence of grokking. Our findings establish a theoretical link between noise stability, model interpretability, and training dynamics.
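The regularization idea described above can be illustrated with a minimal sketch. The paper's exact penalty is not reproduced here: `stability_penalty`, the choice of ρ-correlated Gaussian noise, and all parameter values are illustrative assumptions, not the authors' implementation.

```python
import math
import random

def stability_penalty(f, x, rho, samples=64, rng=None):
    """Hypothetical noise-stability regularizer: the mean squared change
    in f when x is replaced by a rho-correlated Gaussian copy of itself.
    A noise-stable model yields a small penalty."""
    rng = rng or random.Random(0)
    base = f(x)
    total = 0.0
    for _ in range(samples):
        # rho-correlated copy: shrink toward x, add independent noise.
        x_noisy = [rho * xi + math.sqrt(1 - rho ** 2) * rng.gauss(0, 1)
                   for xi in x]
        total += (f(x_noisy) - base) ** 2
    return total / samples

# A model that reads one coordinate pays a much smaller penalty than one
# that sums every coordinate, mirroring the "junta-like" simplicity bias.
x = [1.0] * 16
p_junta = stability_penalty(lambda v: v[0], x, rho=0.95)
p_dense = stability_penalty(lambda v: sum(v), x, rho=0.95)
```

In training, such a penalty would be added to the task loss with a weight hyperparameter, so the optimizer trades task fit against robustness to correlated input noise.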
Abstract
Understanding simplicity biases in deep learning offers a promising path toward developing reliable AI. A common metric for this, inspired by Boolean function analysis, is average sensitivity, which captures a model's robustness to single-token perturbations. We argue that average sensitivity has two key limitations: it lacks a natural generalization to real-valued domains and fails to explain the "junta-like" input dependence we empirically observe in modern LLMs. To address these limitations, we propose noise stability as a more comprehensive simplicity metric. Noise stability expresses a model's robustness to correlated noise applied to all input coordinates simultaneously. We provide a theoretical analysis of noise stability for single-layer attention and ReLU MLP layers and tackle the multi-layer propagation problem with a covariance interval propagation approach. Building on this theory, we develop a practical noise stability regularization method. Experiments on algorithmic and next-token-prediction tasks show that our regularizer consistently catalyzes grokking and accelerates training by approximately $35\%$ and $75\%$ respectively. Our results draw a new connection between signal propagation in neural networks and interpretability, with noise stability emerging as a powerful tool for understanding and improving modern Transformers.
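The notion of noise stability from Boolean function analysis can be estimated by Monte Carlo sampling. The sketch below uses the classical definition $\mathrm{Stab}_\rho[f] = \mathbb{E}[f(x)\,f(y)]$ for $\rho$-correlated Gaussian inputs; the paper's real-valued generalization for Transformers is not specified here, so the functions and parameters are illustrative assumptions. It also shows why a "junta" (a function depending on few coordinates) is more noise-stable than a function depending on all of them.

```python
import math
import random

def noise_stability(f, dim, rho, trials=20000, seed=0):
    """Monte-Carlo estimate of E[f(x) * f(y)], where y is a
    rho-correlated Gaussian copy of x."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        x = [rng.gauss(0, 1) for _ in range(dim)]
        z = [rng.gauss(0, 1) for _ in range(dim)]
        y = [rho * xi + math.sqrt(1 - rho ** 2) * zi
             for xi, zi in zip(x, z)]
        total += f(x) * f(y)
    return total / trials

# A 1-junta (depends on one coordinate) versus parity over all signs.
junta = lambda v: math.copysign(1.0, v[0])
parity = lambda v: math.prod(math.copysign(1.0, vi) for vi in v)

s_junta = noise_stability(junta, dim=8, rho=0.9)    # ~ (2/pi)*asin(0.9)
s_parity = noise_stability(parity, dim=8, rho=0.9)  # ~ the above to the 8th power
```

By Sheppard's formula the junta's stability is $\tfrac{2}{\pi}\arcsin\rho \approx 0.71$ at $\rho = 0.9$, while the 8-coordinate parity drops to roughly $0.71^8 \approx 0.07$, illustrating the junta-like simplicity bias the abstract refers to.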