laplax -- Laplace Approximations with JAX

📅 2025-07-22
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the inefficiency and poor scalability of weight-space uncertainty quantification in Bayesian deep learning. The authors propose an efficient Laplace approximation framework built on JAX, leveraging its automatic differentiation and functional programming capabilities to enable modular, low-dependency second-order posterior approximation. The method supports rapid computation of the posterior weight covariance for large-scale neural networks, enabling principled predictive uncertainty estimation and model selection via Occam's razor. The primary contribution is laplax, an open-source, lightweight toolkit that systematically adapts the Laplace approximation to modern deep learning architectures and distributed training paradigms. By preserving theoretical rigor while drastically reducing implementation complexity, laplax lowers the barrier to Bayesian inference and facilitates practical uncertainty modeling and algorithmic advancement in large language models and other foundation models.
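The model-selection angle mentioned above rests on the Laplace estimate of the log evidence, log p(y | X) ≈ -L(w_map) + (d/2) log 2π - (1/2) log det H, where L is the negative log joint and H its Hessian at the MAP. The following is a hedged sketch of that calculation on a toy linear-Gaussian model (where the approximation is exact); it does not use laplax's actual API, and names like `neg_log_joint` and `ridge_map` are illustrative.

```python
import jax
import jax.numpy as jnp

def neg_log_joint(w, X, y, prior_prec):
    # Unit-variance Gaussian likelihood plus isotropic Gaussian prior on w.
    resid = X @ w - y
    n, d = X.shape
    return (0.5 * jnp.sum(resid ** 2) + 0.5 * n * jnp.log(2 * jnp.pi)
            + 0.5 * prior_prec * jnp.sum(w ** 2)
            - 0.5 * d * jnp.log(prior_prec / (2 * jnp.pi)))

def laplace_log_evidence(w_map, X, y, prior_prec):
    # Laplace evidence: -L(w_map) + (d/2) log 2pi - (1/2) log det H.
    H = jax.hessian(neg_log_joint)(w_map, X, y, prior_prec)
    _, logdet = jnp.linalg.slogdet(H)
    return (-neg_log_joint(w_map, X, y, prior_prec)
            + 0.5 * w_map.size * jnp.log(2 * jnp.pi) - 0.5 * logdet)

# Toy data; for this model the MAP is ridge regression in closed form.
X = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
y = X @ jnp.array([1.0, -2.0, 0.5])

def ridge_map(prior_prec):
    return jnp.linalg.solve(X.T @ X + prior_prec * jnp.eye(3), X.T @ y)

# Occam's razor in action: a weak prior explains this data far better
# than an extreme one, and the evidence reflects that.
ev_weak = laplace_log_evidence(ridge_map(1.0), X, y, 1.0)
ev_strong = laplace_log_evidence(ridge_map(1000.0), X, y, 1000.0)
```

Comparing `ev_weak` and `ev_strong` for candidate priors (or architectures) is the evidence-based model selection the summary refers to.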

📝 Abstract
The Laplace approximation provides a scalable and efficient means of quantifying weight-space uncertainty in deep neural networks, enabling the application of Bayesian tools such as predictive uncertainty and model selection via Occam's razor. In this work, we introduce laplax, a new open-source Python package for performing Laplace approximations with jax. Designed with a modular and purely functional architecture and minimal external dependencies, laplax offers a flexible and researcher-friendly framework for rapid prototyping and experimentation. Its goal is to facilitate research on Bayesian neural networks, uncertainty quantification for deep learning, and the development of improved Laplace approximation techniques.
Problem

Research questions and friction points this paper is trying to address.

Quantify weight-space uncertainty in deep neural networks
Provide scalable Bayesian tools for predictive uncertainty
Facilitate research on improved Laplace approximation techniques
Innovation

Methods, ideas, or system contributions that make the work stand out.

Laplace approximation for uncertainty quantification
Modular functional architecture with JAX
Open-source Python package for rapid prototyping
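The core recipe behind these contributions can be sketched in a few lines of JAX. This is a hypothetical minimal example of a full Laplace approximation on a toy linear model, not laplax's actual API; names like `nll` and `x_star` are illustrative.

```python
import jax
import jax.numpy as jnp

def nll(w, X, y, prior_prec=1.0):
    # Negative log joint: squared-error likelihood + Gaussian prior on w.
    resid = X @ w - y
    return 0.5 * jnp.sum(resid ** 2) + 0.5 * prior_prec * jnp.sum(w ** 2)

X = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
y = X @ jnp.array([1.0, -2.0, 0.5]) \
    + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (50,))

# 1. Find the MAP estimate (plain gradient descent for brevity).
w = jnp.zeros(3)
step = jax.jit(lambda w: w - 0.005 * jax.grad(nll)(w, X, y))
for _ in range(1000):
    w = step(w)

# 2. Curvature at the MAP is the posterior precision; invert for the
#    approximate posterior covariance over the weights.
H = jax.hessian(nll)(w, X, y)
cov = jnp.linalg.inv(H)

# 3. Predictive variance at a test input under the linearized model.
x_star = jnp.ones(3)
pred_var = x_star @ cov @ x_star
```

`jax.hessian` composes forward- and reverse-mode autodiff, which is what makes such second-order quantities cheap to prototype; scaling beyond toy models requires structured curvature approximations (e.g. GGN or KFAC) rather than the dense Hessian used here.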
Tobias Weber
Tübingen AI Center, University of Tübingen, Tübingen, Germany
Bálint Mucsányi
University of Tübingen
Uncertainty Quantification · Probabilistic ML · Computer Vision
Lenard Rommel
Tübingen AI Center, University of Tübingen, Tübingen, Germany
Thomas Christie
Tübingen AI Center, University of Tübingen, Tübingen, Germany
Lars Kasüschke
Tübingen AI Center, University of Tübingen, Tübingen, Germany
Marvin Pförtner
University of Tübingen
Gaussian Processes · Bayesian Deep Learning · Probabilistic Numerics · Machine Learning
Philipp Hennig
University of Tübingen
Probabilistic Numerics · Machine Learning · Computer Science