Training NTK to Generalize with KARE

📅 2025-05-16
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
This work challenges the paradigm of end-to-end training of deep neural networks (DNNs) by proposing **direct optimization of the Neural Tangent Kernel (NTK) to minimize generalization error**. Methodologically, it introduces a differentiable surrogate objective for generalization error—Kernel Alignment Risk Estimator (KARE)—and applies gradient descent to optimize NTK parameters explicitly, enabling kernel learning in the overparameterized regime. Theoretically, it integrates kernel alignment analysis to elucidate generalization mechanisms; empirically, it demonstrates on multiple synthetic and real-world benchmarks that KARE-optimized NTKs achieve stable generalization performance matching or significantly surpassing both the original DNN and the posterior (“after-kernel”) NTK. Crucially, this is the first work to elevate the NTK from a purely analytical tool to a trainable object, establishing a new paradigm for feature learning that bridges theoretical interpretability with strong empirical competitiveness.

Technology Category

Application Category

📝 Abstract
The performance of the data-dependent neural tangent kernel (NTK; Jacot et al. (2018)) associated with a trained deep neural network (DNN) often matches or exceeds that of the full network. This implies that DNN training via gradient descent implicitly performs kernel learning by optimizing the NTK. In this paper, we propose instead to optimize the NTK explicitly. Rather than minimizing empirical risk, we train the NTK to minimize its generalization error using the recently developed Kernel Alignment Risk Estimator (KARE; Jacot et al. (2020)). Our simulations and real data experiments show that NTKs trained with KARE consistently match or significantly outperform the original DNN and the DNN- induced NTK (the after-kernel). These results suggest that explicitly trained kernels can outperform traditional end-to-end DNN optimization in certain settings, challenging the conventional dominance of DNNs. We argue that explicit training of NTK is a form of over-parametrized feature learning.
Problem

Research questions and friction points this paper is trying to address.

Explicitly optimizes NTK to minimize generalization error
Compares KARE-trained NTK with DNN and after-kernel performance
Challenges traditional DNN dominance via over-parametrized feature learning
Innovation

Methods, ideas, or system contributions that make the work stand out.

Explicitly optimize NTK for better performance
Use KARE to minimize NTK generalization error
Outperform traditional DNN with trained NTK
🔎 Similar Papers
No similar papers found.