🤖 AI Summary
This work addresses a limitation of Bayesian Last Layer (BLL) methods, which perform Bayesian inference only on the final layer of a neural network and consequently underestimate uncertainty by neglecting epistemic uncertainty from preceding layers. To overcome this, the authors propose an improved approach that explicitly incorporates network-wide variability by projecting Neural Tangent Kernel (NTK) features onto the last-layer feature space, provably increasing posterior variance. A uniform subsampling strategy is introduced to mitigate computational overhead, accompanied by theoretical bounds on its approximation error. Empirical evaluations across UCI regression, contextual bandits, image classification, and out-of-distribution detection tasks demonstrate that the method outperforms standard BLLs and state-of-the-art baselines, achieving superior calibration and uncertainty estimation while maintaining computational efficiency.
📝 Abstract
Bayesian Last Layers (BLLs) provide a convenient and computationally efficient way to estimate uncertainty in neural networks. However, they underestimate epistemic uncertainty because they apply a Bayesian treatment only to the final layer, ignoring uncertainty induced by earlier layers. We propose a method that improves BLLs by leveraging a projection of Neural Tangent Kernel (NTK) features onto the space spanned by the last-layer features. This enables posterior inference that accounts for variability of the full network while retaining the low inference cost of a standard BLL. We show that our method yields posterior variances that are provably greater than or equal to those of a standard BLL, correcting its tendency to underestimate epistemic uncertainty. To further reduce computational cost, we introduce a uniform subsampling scheme for estimating the projection matrix and for posterior inference, and we derive approximation bounds for both types of subsampling. Empirical evaluations on UCI regression, contextual bandits, image classification, and out-of-distribution detection tasks on image and tabular datasets demonstrate improved calibration and uncertainty estimates compared to standard BLLs and competitive baselines, while reducing computational cost.
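For context, the standard BLL baseline that the abstract describes amounts to Bayesian linear regression over the network's last-layer features. Below is a minimal sketch of that baseline, assuming a Gaussian likelihood with noise precision `beta` and an isotropic Gaussian prior with precision `alpha` on the last-layer weights; the feature matrix `Phi`, the targets `y`, and all parameter values are hypothetical placeholders, not the paper's setup. The proposed method augments this construction with projected NTK features, which is claimed to yield predictive variances at least as large as this baseline's.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical last-layer features Phi (n x d) and regression targets y.
n, d = 50, 8
Phi = rng.normal(size=(n, d))
w_true = rng.normal(size=d)
y = Phi @ w_true + 0.1 * rng.normal(size=n)

alpha, beta = 1.0, 100.0  # assumed prior precision and noise precision

# Standard BLL posterior: Bayesian linear regression over Phi.
S_inv = alpha * np.eye(d) + beta * Phi.T @ Phi  # posterior precision
S = np.linalg.inv(S_inv)                        # posterior covariance
m = beta * S @ Phi.T @ y                        # posterior mean

# Predictive mean and variance at a test feature vector phi_star.
phi_star = rng.normal(size=d)
pred_mean = phi_star @ m
pred_var = 1.0 / beta + phi_star @ S @ phi_star
```

The epistemic term `phi_star @ S @ phi_star` reflects only last-layer uncertainty, which is the underestimation the paper targets.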