🤖 AI Summary
Existing flatness measures struggle to accurately characterize the generalization ability of convolutional neural networks (CNNs): they are often tailored to fully connected networks or overlook the geometric structure inherent to CNNs. This work addresses that limitation by focusing on a canonical CNN architecture comprising global average pooling followed by a linear classifier. For the first time, it derives a closed-form expression for the trace of the Hessian of the cross-entropy loss in this setting and proposes a structure-aware relative flatness metric that explicitly accounts for the scale symmetries induced by convolution and pooling operations, as well as inter-filter interactions. Empirical evaluations demonstrate that the proposed metric effectively assesses and compares the generalization performance of CNN models, offering theoretical guidance for architecture design and training strategies.
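To make the setting concrete, the sketch below is a minimal NumPy illustration (not the paper's implementation) of the canonical architecture studied: one convolutional layer, global average pooling, and a linear classifier trained with cross-entropy. The function names, shapes, and single-channel input are illustrative assumptions.

```python
import numpy as np

def conv2d_valid(x, kernels):
    """Valid-mode 2-D correlation of one single-channel image with F kernels."""
    H, W = x.shape
    F, kH, kW = kernels.shape
    out = np.empty((F, H - kH + 1, W - kW + 1))
    for f in range(F):
        for i in range(H - kH + 1):
            for j in range(W - kW + 1):
                out[f, i, j] = np.sum(x[i:i + kH, j:j + kW] * kernels[f])
    return out

def forward_loss(x, y, kernels, V):
    """Conv -> global average pooling -> linear classifier -> cross-entropy."""
    feats = conv2d_valid(x, kernels)   # (F, H', W') feature maps
    z = feats.mean(axis=(1, 2))        # global average pooling: one scalar per filter
    logits = V @ z                     # linear classifier, (C,) class scores
    logits = logits - logits.max()     # shift for numerical stability
    p = np.exp(logits) / np.exp(logits).sum()
    return -np.log(p[y])               # cross-entropy for the true class y
```

Because global average pooling collapses each feature map to a single scalar, the loss depends on the kernels only through these per-filter averages, which is what makes a closed-form Hessian trace tractable in this setting.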
📝 Abstract
Flatness measures based on the spectrum or the trace of the Hessian of the loss are widely used as proxies for the generalization ability of deep networks. However, most existing definitions are either tailored to fully connected architectures, relying on stochastic estimators of the Hessian trace, or ignore the specific geometric structure of modern Convolutional Neural Networks (CNNs). In this work, we develop a flatness measure that is both exact and architecturally faithful for a broad and practically relevant class of CNNs. We first derive a closed-form expression for the trace of the Hessian of the cross-entropy loss with respect to the convolutional kernels in networks that use global average pooling followed by a linear classifier. Building on this result, we specialize the notion of relative flatness to convolutional layers and obtain a parameterization-aware flatness measure that properly accounts for the scaling symmetries and filter interactions induced by convolution and pooling. Finally, we empirically investigate the proposed measure on families of CNNs trained on standard image-classification benchmarks. The results suggest that the measure can serve as a robust tool for assessing and comparing the generalization performance of CNN models, and for guiding architecture and training choices in practice.
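Since the paper's closed form is not reproduced here, a generic way to sanity-check any such formula is a brute-force numerical Hessian trace via central second differences along each parameter coordinate. The sketch below is an illustration of that standard estimator, not the paper's derivation; the quadratic test function is an assumption chosen because its Hessian trace is known exactly.

```python
import numpy as np

def hessian_trace_fd(loss_fn, theta, h=1e-3):
    """Estimate tr(H) of loss_fn at theta via central second differences:
    tr(H) ~ sum_i [L(theta + h e_i) - 2 L(theta) + L(theta - h e_i)] / h^2."""
    base = loss_fn(theta)
    trace = 0.0
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e.flat[i] = h                 # perturb one coordinate at a time
        trace += (loss_fn(theta + e) - 2.0 * base + loss_fn(theta - e)) / h ** 2
    return trace

# Sanity check on L(theta) = 0.5 * theta^T A theta, whose Hessian is exactly A,
# so the estimated trace should match tr(A) = 6.
A = np.diag([1.0, 2.0, 3.0])
quadratic = lambda t: 0.5 * t @ A @ t
theta0 = np.array([0.3, -0.7, 1.1])
```

A parameterization-aware measure in the spirit of relative flatness would additionally rescale this raw trace by the squared norm of the layer's weights, so that the value is invariant under the scaling symmetries the abstract mentions; the exact rescaling used in the paper is not reproduced here.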