🤖 AI Summary
This work addresses the poorly understood mechanism of *m*-sharpness in Sharpness-Aware Minimization (SAM): the counterintuitive observation that generalization improves monotonically as the micro-batch size used to compute the perturbation decreases. We propose a theory–algorithm co-design. First, within an extended stochastic differential equation (SDE) framework, combined with a structural analysis of stochastic gradient noise, we theoretically establish that the mini-batch noise introduced during SAM's perturbation step acts as a variance-based sharpness regularizer, inherently biasing optimization toward flatter minima. Second, building on this insight, we design Reweighted SAM, which approximates the generalization benefit of *m*-SAM via sharpness-weighted sampling, preserving the parallelism of a single full-batch perturbation and avoiding the computational bottleneck of per-micro-batch perturbations. Our theoretical analysis is validated comprehensively across diverse multi-task and multi-architecture benchmarks: Reweighted SAM achieves significant generalization gains while outperforming existing SAM variants in training efficiency.
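To make the distinction concrete, here is a minimal numpy sketch of a SAM step on a least-squares objective, with an optional micro-batch size `m` that turns it into *m*-SAM (each micro-batch gets its own ascent perturbation). The loss, step sizes, and the function names `loss_grad` / `sam_step` are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def loss_grad(w, X, y):
    # Gradient of the mean-squared-error loss 0.5 * mean((Xw - y)^2).
    return X.T @ (X @ w - y) / len(y)

def sam_step(w, X, y, rho=0.05, lr=0.1, m=None):
    """One SAM update. With m=None a single full-batch perturbation is used;
    with m set, the batch is split into micro-batches of size m, each with
    its own perturbation (m-SAM), and the perturbed gradients are averaged."""
    n = len(y)
    m = n if m is None else m
    grads = []
    for start in range(0, n, m):
        Xi, yi = X[start:start + m], y[start:start + m]
        g = loss_grad(w, Xi, yi)
        # Normalized ascent step of radius rho toward higher loss.
        eps = rho * g / (np.linalg.norm(g) + 1e-12)
        # Descent gradient evaluated at the perturbed point.
        grads.append(loss_grad(w + eps, Xi, yi))
    return w - lr * np.mean(grads, axis=0)
```

The micro-batch loop is what makes *m*-SAM expensive in practice: the perturbations cannot be batched into one forward/backward pass, which is exactly the bottleneck Reweighted SAM is designed to avoid.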
📝 Abstract
Sharpness-aware minimization (SAM) has emerged as a highly effective technique for improving model generalization, but its underlying principles are not fully understood. We investigate the phenomenon known as *m*-sharpness, in which SAM's performance improves monotonically as the micro-batch size used to compute perturbations decreases. Leveraging an extended stochastic differential equation (SDE) framework, combined with an analysis of the structure of stochastic gradient noise (SGN), we precisely characterize the dynamics of various SAM variants. Our findings reveal that the stochastic noise introduced during SAM perturbations inherently induces a variance-based sharpness regularization effect. Motivated by these theoretical insights, we introduce Reweighted SAM, which employs sharpness-weighted sampling to mimic the generalization benefits of *m*-SAM while remaining parallelizable. Comprehensive experiments validate our theoretical analysis and the effectiveness of the proposed method.
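The sharpness-weighted sampling idea can be sketched as follows: score each example by how much its loss rises under the shared perturbation, then draw the next batch with probability proportional to that score, so a single parallel perturbation still emphasizes sharp examples. This is a hypothetical numpy illustration on a least-squares loss; the sharpness proxy, the `per_example_sharpness` / `reweighted_batch` names, and the sampling scheme are assumptions for exposition, not the paper's exact algorithm.

```python
import numpy as np

rng = np.random.default_rng(0)

def per_example_sharpness(w, X, y, rho=0.05):
    """Proxy sharpness score: per-example loss increase at the point
    perturbed along the (shared) full-batch ascent direction."""
    g = X.T @ (X @ w - y) / len(y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    base = 0.5 * (X @ w - y) ** 2          # per-example loss at w
    pert = 0.5 * (X @ (w + eps) - y) ** 2  # per-example loss at w + eps
    return np.maximum(pert - base, 0.0)

def reweighted_batch(w, X, y, batch_size, rho=0.05):
    """Sample a batch with probability proportional to the sharpness proxy,
    so sharper examples are seen more often under one shared perturbation."""
    s = per_example_sharpness(w, X, y, rho) + 1e-8  # keep probabilities valid
    p = s / s.sum()
    idx = rng.choice(len(y), size=batch_size, replace=False, p=p)
    return X[idx], y[idx]
```

Because only one perturbation is computed per step, the forward/backward passes remain fully batched, which is how this construction keeps the parallelism that per-micro-batch *m*-SAM gives up.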