🤖 AI Summary
Bayesian model averaging (BMA) generalizes poorly when the posterior distribution is overly sharp, a condition commonly induced by standard approximate inference methods such as mean-field variational inference, which fail to capture the flat, well-generalizing regions of the loss landscape.
Method: We propose Flat Posterior-aware Bayesian Model Averaging (FP-BMA), a novel Bayesian training objective that explicitly encourages flat posteriors. FP-BMA combines a curvature-sensitive prior with KL regularization and uses geometric analysis of the loss landscape to enforce flatness-aware posterior inference. We further extend it to a flat-posterior-aware Bayesian transfer learning scheme for downstream tasks.
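The core idea of a flatness-aware posterior objective can be sketched as a sharpness-aware (SAM-style) worst-case loss evaluated at a reparameterized posterior sample, plus a KL term toward a Gaussian prior. Everything below is an illustrative assumption, not the paper's implementation: the toy quadratic loss, the function names (`fp_bma_objective`, `kl_gauss`), and the hyperparameters `rho` and `beta` are stand-ins.

```python
import numpy as np

def loss(w, A, b):
    # Toy quadratic "training loss"; its curvature (flat vs. sharp) is set by A.
    return 0.5 * w @ A @ w - b @ w

def grad_loss(w, A, b):
    return A @ w - b

def kl_gauss(mu, sigma, mu0=0.0, sigma0=1.0):
    # KL( N(mu, diag(sigma^2)) || N(mu0, sigma0^2 I) ), elementwise-summed.
    return np.sum(np.log(sigma0 / sigma)
                  + (sigma**2 + (mu - mu0)**2) / (2 * sigma0**2) - 0.5)

def fp_bma_objective(mu, sigma, A, b, rho=0.05, beta=1e-2, rng=None):
    """Sketch of a flat-posterior objective: worst-case loss near a
    posterior sample plus a KL regularizer toward the prior."""
    if rng is None:
        rng = np.random.default_rng(0)
    w = mu + sigma * rng.standard_normal(mu.shape)   # reparameterized sample
    g = grad_loss(w, A, b)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)      # SAM-style ascent step
    sharpness_aware = loss(w + eps, A, b)            # loss at perturbed weights
    return sharpness_aware + beta * kl_gauss(mu, sigma)
```

With the same posterior, the objective is larger under a sharp curvature matrix `A` than a flat one, which is exactly the signal a flatness-promoting objective penalizes.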
Results: On multiple benchmark datasets, FP-BMA significantly improves BMA test accuracy over standard baselines. The experiments consistently show that flattening the posterior yields substantial and robust generalization gains, directly linking posterior geometry to generalization performance in Bayesian deep learning.
📝 Abstract
Bayesian neural networks (BNNs) estimate the posterior distribution of model parameters and utilize posterior samples for Bayesian Model Averaging (BMA) in prediction. However, despite the crucial role of loss-landscape flatness in improving the generalization of neural networks, its impact on BMA has been largely overlooked. In this work, we explore how posterior flatness influences BMA generalization and empirically demonstrate that (1) most approximate Bayesian inference methods fail to yield a flat posterior and (2) BMA predictions that ignore posterior flatness are less effective at improving generalization. To address this, we propose Flat Posterior-aware Bayesian Model Averaging (FP-BMA), a novel training objective that explicitly encourages flat posteriors in a principled Bayesian manner. We also introduce a Flat Posterior-aware Bayesian Transfer Learning scheme that enhances generalization in downstream tasks. Empirically, we show that FP-BMA successfully captures flat posteriors, improving generalization performance.
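The BMA prediction step the abstract refers to, averaging the predictive distribution over posterior samples, can be sketched in a few lines. The Gaussian posterior over weights and the tiny linear-softmax classifier below are illustrative assumptions, not the paper's architecture.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def bma_predict(x, mu, sigma, n_samples=32, rng=None):
    """Bayesian Model Averaging: average class probabilities over
    weight samples W ~ N(mu, diag(sigma^2))."""
    if rng is None:
        rng = np.random.default_rng(0)
    probs = np.zeros((x.shape[0], mu.shape[1]))
    for _ in range(n_samples):
        W = mu + sigma * rng.standard_normal(mu.shape)  # posterior sample
        probs += softmax(x @ W)                         # predict with sample
    return probs / n_samples                            # averaged predictive
```

Because each term is a valid probability distribution, the average is too; if the posterior samples concentrate in a sharp region, the averaged prediction gains little over a single sample, which is the failure mode FP-BMA targets.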