🤖 AI Summary
Bayesian neural networks (BNNs) commonly suffer from overconfidence, hyperparameter sensitivity, and posterior collapse; in prior message-passing (MP) methods, overconfidence arises in particular from double-counting training data during inference. This work proposes a factor-graph-based MP framework for BNNs that, to the authors' knowledge, is the first MP method to handle convolutional neural networks (CNNs) and to avoid double-counting in parameter updates, thereby mitigating overconfidence. On CIFAR-10 with a CNN of roughly 890k parameters, the approach competes with the AdamW and IVON baselines and shows an edge in calibration. On synthetic data, posterior credible intervals correlate strongly (0.9) with their probability of covering the true data-generating function outside the training range. The framework further scales to an MLP with 5.6 million parameters, establishing a step toward scalable, theoretically consistent uncertainty quantification in deep learning, though additional work is needed to match state-of-the-art variational inference at larger scales.
📝 Abstract
Bayesian neural networks (BNNs) offer the potential for reliable uncertainty quantification and interpretability, which are critical for trustworthy AI in high-stakes domains. However, existing methods often struggle with issues such as overconfidence, hyperparameter sensitivity, and posterior collapse, leaving room for alternative approaches. In this work, we advance message passing (MP) for BNNs and present a novel framework that models the predictive posterior as a factor graph. To the best of our knowledge, our framework is the first MP method that handles convolutional neural networks and avoids double-counting training data, a limitation of previous MP methods that causes overconfidence. We evaluate our approach on CIFAR-10 with a convolutional neural network of roughly 890k parameters and find that it can compete with the SOTA baselines AdamW and IVON, even having an edge in terms of calibration. On synthetic data, we validate the uncertainty estimates and observe a strong correlation (0.9) between posterior credible intervals and their probability of covering the true data-generating function outside the training range. While our method scales to an MLP with 5.6 million parameters, further improvements are necessary to match the scale and performance of state-of-the-art variational inference methods.
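To make the calibration claim concrete, the sketch below illustrates one way such a correlation can be computed: given a Gaussian posterior predictive (mean and standard deviation per extrapolation input) and the true function values, it compares nominal credible levels with their empirical coverage and reports the Pearson correlation. This is a hypothetical illustration under assumed names and toy data, not the paper's code or its exact metric.

```python
import numpy as np
from statistics import NormalDist  # stdlib inverse Gaussian CDF

def coverage_correlation(mu, sigma, y_true, levels):
    """Pearson correlation between nominal credible levels and the
    empirical coverage of the corresponding central credible intervals.
    (Illustrative helper; names and signature are assumptions.)"""
    empirical = []
    for level in levels:
        # Half-width of the central interval in predictive-std units.
        z = NormalDist().inv_cdf(0.5 + level / 2)
        inside = np.abs(y_true - mu) <= z * sigma
        empirical.append(inside.mean())
    empirical = np.asarray(empirical)
    return np.corrcoef(levels, empirical)[0, 1], empirical

# Toy setup: a slightly underconfident predictive (std 1.2 vs. true 1.0),
# standing in for a BNN's posterior predictive outside the training range.
rng = np.random.default_rng(0)
n = 10_000
mu = np.zeros(n)                       # predictive means
sigma = np.full(n, 1.2)                # predictive stds (too wide on purpose)
y_true = rng.normal(0.0, 1.0, n)       # draws from the "true" function
levels = np.array([0.5, 0.6, 0.7, 0.8, 0.9, 0.95])
r, cov = coverage_correlation(mu, sigma, y_true, levels)
```

Because the toy predictive is too wide, empirical coverage exceeds each nominal level, yet the two remain strongly correlated; a miscalibrated model shows up in the coverage values, while the correlation reflects how consistently coverage tracks the credible level.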