🤖 AI Summary
This work investigates how the implicit biases of Adam and SGD differ in deep learning training. While SGD exhibits a strong simplicity bias, favoring linear decision boundaries, Adam demonstrably resists this bias, enabling the emergence of nonlinear features and a closer approximation to the Bayes-optimal classifier. We analyze this phenomenon through theoretical modeling of the population gradient dynamics and through empirical evaluation of two-layer ReLU networks trained on synthetic Gaussian-mixture binary classification tasks. To our knowledge, this is the first study to jointly establish, via rigorous theory and controlled experiments, the mechanism by which Adam mitigates simplicity bias. Our results show that Adam generalizes significantly better than SGD, both in-distribution and under certain distribution shifts, with decision boundaries consistently closer to the Bayes-optimal boundary. These findings provide a novel theoretical and empirical foundation for understanding the superior generalization of adaptive optimizers.
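As a concrete illustration of the kind of setup the summary describes, the sketch below builds a synthetic Gaussian-mixture binary classification task and trains a two-layer ReLU network with either SGD or Adam in PyTorch. The cluster means, network width, learning rate, and training length are hypothetical placeholders chosen so that the Bayes-optimal boundary is nonlinear; they are not the paper's exact configuration.

```python
import torch
import torch.nn as nn

def make_gaussian_mixture(n=2000, d=2, sigma=0.5, seed=0):
    """Synthetic mixture: two Gaussian clusters per class (illustrative parameters)."""
    g = torch.Generator().manual_seed(seed)
    means = torch.tensor([[2.0, 2.0], [-2.0, -2.0],   # clusters for class 1
                          [2.0, -2.0], [-2.0, 2.0]])  # clusters for class 0
    labels = torch.tensor([1, 1, 0, 0])
    idx = torch.randint(0, 4, (n,), generator=g)        # pick a cluster per sample
    x = means[idx] + sigma * torch.randn(n, d, generator=g)
    y = labels[idx].float()
    return x, y

class TwoLayerReLU(nn.Module):
    """Two-layer ReLU network producing a single logit."""
    def __init__(self, d=2, width=100):
        super().__init__()
        self.hidden = nn.Linear(d, width)
        self.out = nn.Linear(width, 1)

    def forward(self, x):
        return self.out(torch.relu(self.hidden(x))).squeeze(-1)

def train(optimizer_name="adam", epochs=500, lr=1e-2):
    """Full-batch training with either Adam or SGD on the synthetic task."""
    x, y = make_gaussian_mixture()
    model = TwoLayerReLU()
    opt = (torch.optim.Adam(model.parameters(), lr=lr) if optimizer_name == "adam"
           else torch.optim.SGD(model.parameters(), lr=lr))
    loss_fn = nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    return model
```

Comparing `train("adam")` and `train("sgd")` on this kind of data is the controlled comparison the summary refers to: the two optimizers can be inspected for how linear or nonlinear their learned decision boundaries are.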
📝 Abstract
Adam is the de facto optimization algorithm for several deep learning applications, but an understanding of its implicit bias and how it differs from other algorithms, particularly standard first-order methods such as (stochastic) gradient descent (GD), remains limited. In practice, neural networks trained with SGD are known to exhibit simplicity bias, a tendency to find simple solutions. In contrast, we show that Adam is more resistant to such simplicity bias. To demystify this phenomenon, in this paper, we investigate the differences in the implicit biases of Adam and GD when training two-layer ReLU neural networks on a binary classification task involving synthetic data with Gaussian clusters. We find that GD exhibits a simplicity bias, resulting in a linear decision boundary with a suboptimal margin, whereas Adam leads to much richer and more diverse features, producing a nonlinear boundary that is closer to the Bayes-optimal predictor. This richer decision boundary also allows Adam to achieve higher test accuracy both in-distribution and under certain distribution shifts. We theoretically prove these results by analyzing the population gradients. To corroborate our theoretical findings, we present empirical results on datasets with spurious correlations, where neural networks trained with SGD are known to show simplicity bias and to generalize poorly under certain distribution shifts, and we show that this property of Adam leads to superior generalization there.
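For a Gaussian-mixture task like the one sketched above (equal-weight clusters with shared isotropic covariance), the Bayes-optimal predictor can be written down explicitly by comparing class-conditional densities, which gives a concrete way to measure how close a trained network's boundary is to the optimal one. The helper below reuses `make_gaussian_mixture` from the previous sketch; the mixture parameters and the agreement metric are illustrative assumptions, not the paper's evaluation protocol.

```python
import torch

def bayes_predict(x, sigma=0.5):
    """Bayes-optimal prediction for the assumed equal-weight, shared-variance mixture."""
    means = torch.tensor([[2.0, 2.0], [-2.0, -2.0],   # clusters for class 1
                          [2.0, -2.0], [-2.0, 2.0]])  # clusters for class 0
    labels = torch.tensor([1, 1, 0, 0])
    # Unnormalized Gaussian densities per cluster (shared isotropic covariance).
    d2 = ((x[:, None, :] - means[None, :, :]) ** 2).sum(-1)
    dens = torch.exp(-d2 / (2 * sigma ** 2))
    p_pos = dens[:, labels == 1].sum(-1)
    p_neg = dens[:, labels == 0].sum(-1)
    return (p_pos > p_neg).float()

def agreement_with_bayes(model, n=10000, sigma=0.5):
    """Fraction of fresh samples on which the network matches the Bayes predictor,
    a rough proxy for how close its decision boundary is to the optimal one."""
    x, _ = make_gaussian_mixture(n=n, sigma=sigma, seed=1)
    with torch.no_grad():
        net_pred = (model(x) > 0).float()
    return (net_pred == bayes_predict(x, sigma)).float().mean().item()
```

Under this kind of proxy, the claim in the abstract corresponds to Adam-trained networks scoring higher agreement with the Bayes predictor than SGD-trained ones; the exact metrics and theoretical statements are in the paper itself.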