🤖 AI Summary
This work addresses the lack of a unified theory of deep learning generalization under overparameterization, feature learning, and noisy data. By analyzing how the empirical neural tangent kernel partitions the output space, the study identifies a mechanism in which error along signal directions decays rapidly, while residual error along noise directions stays trapped where the kernel's eigenvalues are near zero. It further shows that minibatch SGD accumulates the population-level signal through fast linear drift while relegating sample-specific memorization to a slow, diffusive random walk. For the first time, generalization is rigorously established in the full feature-learning regime, where the kernel changes by O(1) in operator norm, giving a unified explanation for benign overfitting, double descent, implicit bias, and grokking. The paper also introduces a population-risk objective that requires no validation set, which in practice reduces to an SNR-based preconditioner adding only one state vector on top of Adam. This preconditioner accelerates grokking by up to 5×, substantially mitigates memorization in PINNs and implicit neural representations, and improves DPO fine-tuning under noisy preferences while keeping the policy 3× closer to the reference policy.
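To make the kernel-partition mechanism concrete, here is a minimal sketch assuming linearized (kernel-regime) gradient descent on squared loss with a fixed empirical NTK K, worked directly in K's eigenbasis. The function name `residual_after_gd`, the eigenvalues, the initial residual, and the step size are hypothetical illustrations, not values from the paper, and the sketch ignores the kernel evolution that the paper's feature-learning result covers.

```python
import math


def residual_after_gd(eigvals, r0, lr, steps):
    """Residual components, in the kernel's eigenbasis, after `steps` of GD.

    Under linearized dynamics on squared loss, r_{k+1} = (I - lr * K) r_k, so the
    component along eigenvector i is multiplied by (1 - lr * lambda_i) each step.
    Requires lr * max(eigvals) < 2 for the large-eigenvalue components to shrink.
    """
    return [c * (1.0 - lr * lam) ** steps for lam, c in zip(eigvals, r0)]


# Hypothetical spectrum: two large "signal" eigenvalues, three near-zero "noise" ones.
eigvals = [5.0, 2.0, 1e-4, 1e-5, 1e-6]
r0 = [1.0, -0.8, 0.5, -0.3, 0.2]  # initial residual expressed in the eigenbasis
lr, steps = 0.1, 200

r = residual_after_gd(eigvals, r0, lr, steps)
signal = math.sqrt(sum(c * c for c in r[:2]))  # essentially eliminated
noise = math.sqrt(sum(c * c for c in r[2:]))   # barely changed from its initial size
initial_noise = math.sqrt(sum(c * c for c in r0[2:]))
print(f"signal-channel residual: {signal:.1e}")
print(f"noise-channel residual:  {noise:.3f} (initial {initial_noise:.3f})")
```

In this toy model the training residual stalls at the noise-channel component rather than reaching zero; the claim that this trapped residual is invisible at test time comes from the paper and is not modeled here.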
📝 Abstract
We present a non-asymptotic theory of generalization in deep learning in which the empirical neural tangent kernel partitions the output space. In directions corresponding to signal, error dissipates rapidly; in the vast orthogonal dimensions corresponding to noise, the kernel's near-zero eigenvalues trap residual error in a test-invisible reservoir. Within the signal channel, minibatch SGD ensures that coherent population signal accumulates via fast linear drift, while idiosyncratic memorization is suppressed into a slow, diffusive random walk. We prove that generalization survives even when the kernel changes by $\mathcal{O}(1)$ in operator norm, i.e., in the full feature-learning regime. This theory naturally explains disparate phenomena in deep learning, such as benign overfitting, double descent, implicit bias, and grokking. Lastly, we derive an exact population-risk objective from a single training run, with no validation data, for any architecture, loss, or optimizer, and prove that it measures precisely the noise in the signal channel. In practice, this objective reduces to an SNR preconditioner on top of Adam that adds one state vector at no extra cost; it accelerates grokking by $5\times$, suppresses memorization in PINNs and implicit neural representations, and improves DPO fine-tuning under noisy preferences while staying $3\times$ closer to the reference policy.
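The drift-versus-diffusion contrast can be illustrated with a toy one-dimensional model. It assumes that each per-sample descent direction (negative gradient), projected onto one parameter direction, equals a shared population component `mu` plus an idiosyncratic zero-mean term drawn fresh at every step; this is a simplification, since real training revisits the same samples. The function name `sgd_displacement` and all numbers are hypothetical. The coherent part grows linearly in the number of steps T, while the idiosyncratic part is a random walk whose typical size scales as sqrt(T/B) for batch size B.

```python
import math
import random


def sgd_displacement(mu, sigma, batch_size, steps, lr, seed=0):
    """Toy 1-D model of minibatch SGD movement along one parameter direction.

    Each per-sample descent direction is mu (coherent population signal) plus an
    idiosyncratic term xi ~ N(0, sigma^2), drawn fresh each step. Returns the
    accumulated movement split into its drift and diffusion parts.
    """
    rng = random.Random(seed)
    drift, diffusion = 0.0, 0.0
    for _ in range(steps):
        # Minibatch average of the idiosyncratic terms: zero mean, std sigma/sqrt(B).
        xi_mean = sum(rng.gauss(0.0, sigma) for _ in range(batch_size)) / batch_size
        drift += lr * mu            # coherent part: grows linearly in steps
        diffusion += lr * xi_mean   # idiosyncratic part: random walk
    return drift, diffusion


# Hypothetical settings, chosen only for illustration.
mu, sigma, batch_size, lr = 0.1, 1.0, 8, 0.01
for steps in (100, 10_000):
    drift, diffusion = sgd_displacement(mu, sigma, batch_size, steps, lr)
    typical = lr * sigma * math.sqrt(steps / batch_size)  # std of the diffusion part
    print(f"T={steps:>6}: drift={drift:.3f}  diffusion={diffusion:+.3f}  "
          f"(typical |diffusion| ~ {typical:.3f})")
```

Going from 100 to 10,000 steps multiplies the drift by 100 but the typical diffusive displacement only by 10, which is the sense in which coherent signal outpaces idiosyncratic memorization in this toy setting.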