🤖 AI Summary
A plain mean-field (MF) theory of finite-width neural networks semi-quantitatively predicts the onset of feature learning (FL) at a symmetry-breaking phase transition. However, it substantially underestimates how much generalization improves after that transition, because it omits a key FL mechanism.
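Method: Using methods from statistical physics, we derive a self-consistent MF theory for the Bayesian posterior of two-layer non-linear networks trained with stochastic gradient Langevin dynamics (SGLD). We then extend it to explicitly model “self-reinforcing input feature selection”. The theory predicts a symmetry-breaking phase transition at which the network abruptly aligns with the target function.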
Contribution/Results: The MF theory semi-quantitatively predicts how the onset of the FL transition depends on noise strength and sample size. Once the feature-selection mechanism is included, the extended theory quantitatively reproduces the post-transition learning curves. Experiments on two-layer nonlinear networks trained with SGLD show close agreement between the theoretical predictions and the measured generalization behavior. The work thereby provides mechanistic insight into how feature learning emerges in finite-width networks.
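To make the training setup concrete, here is a minimal sketch of stochastic gradient Langevin dynamics (SGLD) on a two-layer tanh network, in plain Python with no third-party dependencies. The width `N`, the 1/sqrt(N) and 1/sqrt(D) scalings, the single-feature teacher y = tanh(x_0), the temperature `T`, the weight decay `lam` (a Gaussian prior), the step size `eta` and the batch size are all hypothetical choices for illustration. They are not the settings used in the paper. For small `eta`, and ignoring minibatch gradient noise, the update samples approximately from a density proportional to exp(-(loss + lam/2 |θ|²)/T), where loss is the summed squared error.

```python
import math
import random

def forward(x, W, a):
    # Assumed architecture: f(x) = (1/sqrt(N)) * sum_i a_i * tanh(w_i . x / sqrt(D))
    N, D = len(a), len(x)
    h = [math.tanh(sum(wij * xj for wij, xj in zip(wi, x)) / math.sqrt(D)) for wi in W]
    return sum(ai * hi for ai, hi in zip(a, h)) / math.sqrt(N), h

def mse(X, y, W, a):
    return sum((forward(x, W, a)[0] - t) ** 2 for x, t in zip(X, y)) / len(X)

def sgld(X, y, N=16, T=1e-3, lam=1e-2, eta=1e-2, steps=2000, batch=8, seed=0):
    # All hyperparameter defaults are hypothetical, chosen only so the toy example runs quickly.
    rng = random.Random(seed)
    P, D = len(X), len(X[0])
    W = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(N)]  # first-layer weights
    a = [rng.gauss(0.0, 1.0) for _ in range(N)]                      # readout weights
    noise_std = math.sqrt(2.0 * eta * T)
    for _ in range(steps):
        idx = rng.sample(range(P), batch)
        gW = [[0.0] * D for _ in range(N)]
        ga = [0.0] * N
        for mu in idx:
            f, h = forward(X[mu], W, a)
            r = f - y[mu]  # derivative of 0.5 * (f - y)^2 with respect to f
            for i in range(N):
                ga[i] += r * h[i] / math.sqrt(N)
                c = r * a[i] * (1.0 - h[i] ** 2) / math.sqrt(N * D)
                for j in range(D):
                    gW[i][j] += c * X[mu][j]
        scale = P / batch  # rescale the minibatch sum to estimate the full-data gradient
        for i in range(N):
            # Langevin step: descend (loss + lam/2 |theta|^2) and add noise of variance 2*eta*T
            a[i] += -eta * (scale * ga[i] + lam * a[i]) + noise_std * rng.gauss(0.0, 1.0)
            for j in range(D):
                W[i][j] += -eta * (scale * gW[i][j] + lam * W[i][j]) + noise_std * rng.gauss(0.0, 1.0)
    return W, a

# Toy usage with a hypothetical teacher that depends only on the first input coordinate
data_rng = random.Random(1)
D, P = 4, 40
X = [[data_rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(P)]
y = [math.tanh(x[0]) for x in X]
W0, a0 = sgld(X, y, steps=0)  # same seed, so this is the untrained initialisation
W, a = sgld(X, y)
mse_before, mse_after = mse(X, y, W0, a0), mse(X, y, W, a)
print(f"train MSE: {mse_before:.3f} -> {mse_after:.3f}")
```

Because `sgld` is seeded, calling it with `steps=0` returns the exact initialisation, which gives a clean before/after comparison. Sampling the posterior rather than merely minimising the loss is what the nonzero `T` adds.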
📝 Abstract
Feature learning (FL), where neural networks adapt their internal representations during training, remains poorly understood. Using methods from statistical physics, we derive a tractable, self-consistent mean-field (MF) theory for the Bayesian posterior of two-layer non-linear networks trained with stochastic gradient Langevin dynamics (SGLD). At infinite width, this theory reduces to kernel ridge regression, but at finite width it predicts a symmetry-breaking phase transition in which networks abruptly align with target functions. While the basic MF theory provides theoretical insight into the emergence of FL in the finite-width regime, semi-quantitatively predicting the onset of FL as a function of noise or sample size, it substantially underestimates the improvements in generalisation after the transition. We trace this discrepancy to a key mechanism absent from the plain MF description: self-reinforcing input feature selection. Incorporating this mechanism into the MF theory allows us to quantitatively match the learning curves of SGLD-trained networks and provides mechanistic insight into FL.
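The abstract states that the theory reduces to kernel ridge regression at infinite width. The sketch below shows what such a predictor looks like, f(x) = k(x)ᵀ(K + ridge·I)⁻¹y, in plain Python. The kernel is an assumption. It uses the closed-form arcsine kernel of an infinitely wide erf hidden layer with weights drawn from N(0, I/D) and no bias, chosen only because that kernel has a closed form. The kernel in the paper's theory depends on its activation and prior and may differ. Under a Gaussian likelihood with variance T and a readout prior variance σ_a², the ridge would equal T/σ_a². That mapping is also an assumption about conventions, not taken from the paper.

```python
import math
import random

def arcsine_kernel(x, xp):
    # Assumed kernel: E_w[erf(w.x) erf(w.x')] for w ~ N(0, I/D), no bias term.
    D = len(x)
    def s(u, v):
        return 2.0 * sum(ui * vi for ui, vi in zip(u, v)) / D
    c = s(x, xp) / math.sqrt((1.0 + s(x, x)) * (1.0 + s(xp, xp)))  # |c| < 1 by Cauchy-Schwarz
    return (2.0 / math.pi) * math.asin(c)

def solve(A, b):
    # Gaussian elimination with partial pivoting on the augmented matrix [A | b].
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x

def krr_fit(X, y, kernel, ridge):
    # Solve (K + ridge * I) alpha = y, then predict f(x) = sum_mu alpha_mu * k(x, x_mu).
    K = [[kernel(xi, xj) + (ridge if i == j else 0.0) for j, xj in enumerate(X)]
         for i, xi in enumerate(X)]
    alpha = solve(K, y)
    return lambda x: sum(al * kernel(x, xi) for al, xi in zip(alpha, X))

# Toy usage with a hypothetical single-feature target
rng = random.Random(0)
D, P = 3, 30
X = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(P)]
y = [math.tanh(x[0]) for x in X]

def train_mse(ridge):
    f = krr_fit(X, y, arcsine_kernel, ridge)
    return sum((f(x) - t) ** 2 for x, t in zip(X, y)) / P

mse_small_ridge = train_mse(1e-3)
mse_large_ridge = train_mse(10.0)
print(mse_small_ridge, mse_large_ridge)
```

The training residual of kernel ridge regression is ridge·(K + ridge·I)⁻¹y, so the training error grows with the ridge. A small ridge therefore nearly interpolates the data, while a large ridge shrinks predictions toward zero. In the infinite-width limit this predictor involves no feature learning: the kernel is fixed by the prior and does not adapt to the target.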