🤖 AI Summary
This work addresses the long-standing debate over whether stochastic gradient descent (SGD) inherently favors flat minima during training. By constructing an analytically tractable model, the authors causally demonstrate that SGD itself has no intrinsic bias toward flat solutions; rather, the sharpness of the converged minimum is determined by the isotropy of the label noise in the data: isotropic noise leads to flat minima, whereas anisotropic noise can yield arbitrarily sharp ones. This finding is validated in controlled experiments across diverse architectures, including MLPs, RNNs, and Transformers, highlighting that the data distribution, not the optimizer, governs the sharpness of minima. The results offer a new perspective on the mechanisms underlying generalization in deep learning.
📝 Abstract
A large body of theoretical and empirical work hypothesizes a connection between the flatness of a neural network's loss landscape during training and its performance. However, there is conflicting evidence about when SGD prefers flatter or sharper solutions during training. In this work, we partially but causally clarify the flatness-seeking behavior of SGD by identifying and exactly solving an analytically tractable model that exhibits both flattening and sharpening behavior during training. In this model, SGD has no *a priori* preference for flatness, only a preference for minimal gradient fluctuations. This leads to the insight that, at least within this model, the data distribution uniquely determines the sharpness at convergence, and that a flat minimum is preferred if and only if the noise in the labels is isotropic across all output dimensions. When the label noise is anisotropic, the model instead prefers sharpness and can converge to an arbitrarily sharp solution, depending on the degree of imbalance in the label-noise spectrum. We reproduce this key insight in controlled settings with different model architectures such as MLPs, RNNs, and Transformers.
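The mechanism described above, that SGD seeks minimal gradient fluctuations rather than flatness per se, can be illustrated with a generic toy that is *not* the paper's model: a product-parameterized regression `f(x) = a*b*x` trained with fresh label noise at every step. The global minima form a manifold `{a*b = 1}` whose sharpness grows with `a**2 + b**2`; with (trivially) isotropic single-output label noise, SGD drifts along the manifold toward the balanced, flattest point `a ≈ b`. All constants (`lr`, `sigma`, step count) are assumptions chosen for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy overparameterized model f(x) = a*b*x with target w* = 1.
# Global minima form the manifold {(a, b): a*b = 1}; the curvature of
# the population loss there grows with a**2 + b**2, so the flattest
# minimum is the balanced one a = b = 1.
w_star = 1.0
a, b = 4.0, 0.25       # start at a sharp global minimum (a*b = 1, a != b)
lr, sigma = 0.01, 1.0  # step size and label-noise scale (assumptions)

for step in range(200_000):
    x = rng.choice([-1.0, 1.0])            # bounded input, E[x^2] = 1
    y = w_star * x + sigma * rng.normal()  # fresh (isotropic) label noise
    err = a * b * x - y
    # SGD step on the per-sample squared loss 0.5 * err**2
    a, b = a - lr * err * b * x, b - lr * err * a * x

# The label noise shrinks the imbalance a**2 - b**2 each step, driving
# SGD toward the flattest point of the minimum manifold.
print("a*b =", a * b, " sharpness proxy a^2+b^2 =", a**2 + b**2)
```

Starting from `a**2 + b**2 ≈ 16`, the sharpness proxy decays toward its minimal value of 2 while `a*b` stays near the target, mirroring the flattening behavior the abstract attributes to isotropic label noise.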