Tight Generalization Error Bounds for Stochastic Gradient Descent in Non-convex Learning

📅 2025-06-23

🤖 AI Summary
This work addresses the looseness of existing generalization error bounds for SGD in non-convex learning. It introduces Type II perturbed SGD (T2pm-SGD) and decomposes the generalization error bound into a trajectory term and a flatness term. For bounded losses, the trajectory term is tightened to $O(n^{-1})$, and with an optimally chosen perturbation noise variance the overall bound becomes $O(n^{-2/3})$. The flatness term remains stable across iterations and is smaller than in prior analyses, where it grows with the number of iterations. The framework covers both sub-Gaussian and bounded loss functions through the same perturbation mechanism and error decomposition. Experiments on MNIST and CIFAR-10 support the tightness of the bounds, marking the first work to establish improved generalization guarantees under both loss regimes at once.

📝 Abstract
Stochastic Gradient Descent (SGD) is fundamental for training deep neural networks, especially in non-convex settings. Understanding SGD's generalization properties is crucial for ensuring robust model performance on unseen data. In this paper, we analyze the generalization error bounds of SGD for non-convex learning by introducing the Type II perturbed SGD (T2pm-SGD), which accommodates both sub-Gaussian and bounded loss functions. The generalization error bound is decomposed into two components: the trajectory term and the flatness term. Our analysis improves the trajectory term to $O(n^{-1})$, significantly enhancing the previous $O((nb)^{-1/2})$ bound for bounded losses, where n is the number of training samples and b is the batch size. By selecting an optimal variance for the perturbation noise, the overall bound is further refined to $O(n^{-2/3})$. For sub-Gaussian loss functions, a tighter trajectory term is also achieved. In both cases, the flatness term remains stable across iterations and is smaller than those reported in previous literature, which increase with iterations. This stability, ensured by T2pm-SGD, leads to tighter generalization error bounds for both loss function types. Our theoretical results are validated through extensive experiments on benchmark datasets, including MNIST and CIFAR-10, demonstrating the effectiveness of T2pm-SGD in establishing tighter generalization bounds.
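To give a feel for the rates quoted in the abstract, the sketch below evaluates them numerically. It is only an illustration: every hidden constant is set to 1, the batch size of 64 is an assumed value, and the function name `rate_table` is hypothetical. The training-set sizes (60,000 for MNIST, 50,000 for CIFAR-10) are the standard splits.

```python
def rate_table(n, b):
    """Evaluate the three rates from the abstract with all constants set to 1.

    n: number of training samples, b: batch size (both assumed inputs).
    """
    return {
        "prev_trajectory": (n * b) ** -0.5,   # previous O((nb)^{-1/2}) trajectory term
        "new_trajectory": n ** -1.0,          # improved O(n^{-1}) trajectory term
        "overall": n ** (-2.0 / 3.0),         # overall O(n^{-2/3}) bound
    }


# Assumed batch size; the abstract does not state one.
BATCH_SIZE = 64

for name, n in [("MNIST", 60_000), ("CIFAR-10", 50_000)]:
    rates = rate_table(n, BATCH_SIZE)
    print(name)
    for label, value in rates.items():
        print(f"  {label:16s} {value:.3e}")
```

With constants ignored, the improved trajectory rate is smaller than the previous one whenever $n > b$. The overall $O(n^{-2/3})$ rate also accounts for the flatness term, so it should not be compared directly with either trajectory term alone.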
Problem

Research questions and friction points this paper is trying to address.

Analyze generalization error bounds for SGD in non-convex learning
Improve trajectory term bound from $O((nb)^{-1/2})$ to $O(n^{-1})$
Stabilize flatness term across iterations for tighter bounds
Innovation

Methods, ideas, or system contributions that make the work stand out.

Introduces Type II perturbed SGD (T2pm-SGD)
Improves the overall generalization error bound to $O(n^{-2/3})$
Stable flatness term across iterations
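
The $O(n^{-2/3})$ rate comes from choosing the perturbation noise variance to balance the two terms. The toy calculation below shows how such a balance can produce that exponent. The functional forms are assumptions made purely for illustration, not the paper's actual trajectory and flatness terms: $\sigma$ is a hypothetical noise standard deviation, and $A$, $B$ are hypothetical positive constants.

```latex
% Assumed toy bound: a term that shrinks with noise plus a term that grows with it.
G(\sigma) = \frac{A}{n\sigma} + B\sigma^{2}

% Stationarity in \sigma:
\frac{dG}{d\sigma} = -\frac{A}{n\sigma^{2}} + 2B\sigma = 0
\quad\Longrightarrow\quad
\sigma_{*} = \left(\frac{A}{2B}\right)^{1/3} n^{-1/3}

% Substituting back:
G(\sigma_{*}) = 3 \cdot 2^{-2/3}\, A^{2/3} B^{1/3}\, n^{-2/3} = O\!\left(n^{-2/3}\right)
```

Under these assumed forms, neither term alone sets the rate: a larger $\sigma$ shrinks the first term but inflates the second, and the optimum sits where the two are of the same order.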