🤖 AI Summary
Neural networks trained with gradient descent exhibit a widely observed “simplicity bias”: they tend to learn low-complexity solutions first (e.g., low-rank weight matrices, piecewise-linear functions with few kinks, or solutions that use few convolutional kernels or attention heads) and only gradually increase complexity.
Method: We develop a unified theoretical framework that attributes this bias in fully connected networks, CNNs, and attention-based models to saddle-to-saddle dynamics, in which the gradient descent trajectory switches from one saddle point to the next. Drawing on nonlinear dynamical systems theory, we analyze fixed points, invariant manifolds, and the learning dynamics, and we disentangle the effects of data distribution and weight initialization on learning plateaus.
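Results: We show that linear networks learn solutions of increasing rank, ReLU networks learn solutions with an increasing number of kinks, convolutional networks learn solutions with an increasing number of convolutional kernels, and self-attention models learn solutions with an increasing number of attention heads. Training repeatedly evolves near an invariant manifold, approaches a saddle, and switches to another invariant manifold. The analysis also characterizes how data distribution and initialization determine the number and duration of learning plateaus. Together, these results offer a framework for understanding when and why simpler solutions are learned before more complex ones.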
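The summary's point about data distribution and initialization shaping learning plateaus can be illustrated in the simplest toy setting. The sketch below is not the paper's analysis or code. It assumes a two-layer linear network with whitened inputs, so that the data enter only through the singular values of a diagonal target map. For each singular value, it records the first gradient-descent step at which the network's end-to-end map reaches half of that value. `escape_steps` and all parameter values are hypothetical choices for illustration.

```python
import numpy as np


def escape_steps(target_svals, init_scale, hidden=16, lr=0.05, max_steps=5000, seed=0):
    """Step at which each singular value of W2 @ W1 first reaches half its target.

    Toy assumptions: two-layer linear network, whitened inputs, loss
    0.5 * ||W2 @ W1 - diag(target_svals)||_F^2, target_svals sorted descending.
    Entries stay -1 for modes not reached within max_steps.
    """
    rng = np.random.default_rng(seed)
    target = np.asarray(target_svals, dtype=float)
    d = target.size
    M = np.diag(target)
    W1 = init_scale * rng.standard_normal((hidden, d))
    W2 = init_scale * rng.standard_normal((d, hidden))
    hit = np.full(d, -1)
    for step in range(1, max_steps + 1):
        E = W2 @ W1 - M  # residual of the end-to-end linear map
        # Simultaneous update: both gradients are computed from the pre-step weights.
        W1, W2 = W1 - lr * (W2.T @ E), W2 - lr * (E @ W1.T)
        s = np.linalg.svd(W2 @ W1, compute_uv=False)  # sorted descending
        newly_fit = (hit < 0) & (s >= 0.5 * target)
        hit[newly_fit] = step
        if (hit >= 0).all():
            break
    return hit


spread = [3.0, 1.5, 0.75]  # well-separated singular values
by_scale = {scale: escape_steps(spread, scale) for scale in (1e-2, 1e-4, 1e-6)}
equal = escape_steps([1.0, 1.0, 1.0], 1e-4)  # degenerate singular values

for scale, hits in by_scale.items():
    print(f"init_scale={scale:g}: escape steps {hits}")
print("equal singular values:", equal)
```

In this toy, shrinking `init_scale` should lengthen the first plateau, roughly logarithmically in the scale. Equal singular values tend to be fitted almost together, which merges several plateaus into one. The sketch does not show whether these trends carry over quantitatively to ReLU, convolutional, or attention models.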
📝 Abstract
Neural networks trained with gradient descent often learn solutions of increasing complexity over time, a phenomenon known as simplicity bias. Although this phenomenon is widely observed across architectures, existing theoretical treatments lack a unifying framework. We present a theoretical framework that explains a simplicity bias arising from saddle-to-saddle learning dynamics for a general class of neural networks, including fully-connected, convolutional, and attention-based architectures. Here, simple means expressible with few hidden units, i.e., hidden neurons, convolutional kernels, or attention heads. Specifically, we show that linear networks learn solutions of increasing rank, ReLU networks learn solutions with an increasing number of kinks, convolutional networks learn solutions with an increasing number of convolutional kernels, and self-attention models learn solutions with an increasing number of attention heads. By analyzing fixed points, invariant manifolds, and the dynamics of gradient descent learning, we show that saddle-to-saddle dynamics operates by iteratively evolving near an invariant manifold, approaching a saddle, and switching to another invariant manifold. Our analysis also illuminates the effects of data distribution and weight initialization on the duration and number of plateaus in learning, dissociating previously confounding factors. Overall, our theory offers a framework for understanding when and why gradient descent progressively learns increasingly complex solutions.
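The linear-network claim above (solutions of increasing rank) can be reproduced in a minimal sketch. It assumes a two-layer linear network, whitened inputs, small random initialization, and a diagonal target map with distinct singular values. This is the classical small-initialization toy setting, not the paper's own analysis or code. `train_linear_net` and its parameters are hypothetical names chosen for illustration.

```python
import numpy as np


def train_linear_net(target_svals, hidden=8, init_scale=1e-3, lr=0.05, steps=600, seed=0):
    """Full-batch gradient descent on a two-layer linear network x -> W2 @ W1 @ x.

    Assumption: inputs are whitened (identity covariance), so the population
    loss is 0.5 * ||W2 @ W1 - M||_F^2 with M = diag(target_svals).
    Returns an array of shape (steps, d): singular values of W2 @ W1 after each step.
    """
    rng = np.random.default_rng(seed)
    target = np.asarray(target_svals, dtype=float)
    d = target.size
    M = np.diag(target)
    # Small random initialization places the network near the saddle at the origin.
    W1 = init_scale * rng.standard_normal((hidden, d))
    W2 = init_scale * rng.standard_normal((d, hidden))
    history = np.empty((steps, d))
    for t in range(steps):
        E = W2 @ W1 - M          # residual of the end-to-end linear map
        grad_W1 = W2.T @ E       # dL/dW1
        grad_W2 = E @ W1.T       # dL/dW2
        W1 = W1 - lr * grad_W1
        W2 = W2 - lr * grad_W2
        history[t] = np.linalg.svd(W2 @ W1, compute_uv=False)  # sorted descending
    return history


target = np.array([3.0, 1.5, 0.75])  # distinct singular values, sorted descending
svals = train_linear_net(target)
# First step at which the k-th largest singular value reaches half of its target.
first_hit = [int(np.argmax(svals[:, k] >= 0.5 * target[k])) for k in range(target.size)]
print("steps to fit each mode:", first_hit)
print("final singular values:", np.round(svals[-1], 3))
```

With a small `init_scale`, the largest singular value should be fitted first and the smallest last, with plateaus in between, so the rank of the learned map grows one step at a time. Raising `init_scale` toward the size of the targets tends to blur this ordering.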