Efficiently Vectorized MCMC on Modern Accelerators

📅 2025-03-20
🤖 AI Summary
This work addresses a performance bottleneck in auto-vectorized multi-chain MCMC (e.g., with JAX's vmap): at each iteration, all chains running in parallel must wait for the slowest chain to finish drawing its sample. We propose designing single-chain MCMC algorithms as finite state machines (FSMs), which avoids this inter-chain synchronization overhead when the code is vectorized. Using a simplified model, we derive an exact theoretical form of the obtainable speed-ups and use it to make principled recommendations for optimal algorithm design. We implement several popular samplers as FSMs, including Elliptical Slice Sampling, HMC-NUTS, and Delayed Rejection, and demonstrate speed-ups of up to an order of magnitude over baseline synchronous implementations in experiments.

📝 Abstract
With the advent of automatic vectorization tools (e.g., JAX's vmap), writing multi-chain MCMC algorithms is now often as simple as invoking those tools on single-chain code. Whilst convenient, for various MCMC algorithms this results in a synchronization problem: loosely speaking, at each iteration all chains running in parallel must wait until the last chain has finished drawing its sample. In this work, we show how to design single-chain MCMC algorithms in a way that avoids synchronization overheads when vectorizing with tools like vmap, by using the framework of finite state machines (FSMs). Using a simplified model, we derive an exact theoretical form of the speed-ups obtainable with our approach, and use it to make principled recommendations for optimal algorithm design. We implement several popular MCMC algorithms as FSMs, including Elliptical Slice Sampling, HMC-NUTS, and Delayed Rejection, demonstrating speed-ups of up to an order of magnitude in experiments.

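To make the synchronization problem concrete, here is a toy cost model in plain Python. It is our own illustrative assumption, not the simplified model from the paper: each chain needs a random, geometrically distributed number of inner-loop iterations (as in a slice sampler's shrinkage loop) to produce each sample, and each iteration costs one unit. Under synchronous vectorization every sampling iteration lasts as long as its slowest chain; with an FSM design each chain advances one unit per batched step and starts its next sample immediately. All function names are hypothetical.

```python
import random


def geometric(rng: random.Random, p: float) -> int:
    """Bernoulli(p) trials up to and including the first success (always >= 1)."""
    k = 1
    while rng.random() >= p:
        k += 1
    return k


def draw_work(n_chains: int, n_samples: int, p: float, seed: int = 0) -> list[list[int]]:
    """work[i][j] = inner-loop iterations chain i needs for its j-th sample (toy assumption)."""
    rng = random.Random(seed)
    return [[geometric(rng, p) for _ in range(n_samples)] for _ in range(n_chains)]


def lockstep_steps(work: list[list[int]]) -> int:
    """Synchronous vectorization: each sampling iteration lasts as long as its slowest chain."""
    n_samples = len(work[0])
    return sum(max(chain[j] for chain in work) for j in range(n_samples))


def fsm_steps(work: list[list[int]]) -> int:
    """FSM-style vectorization: each chain does one unit of work per batched step and
    starts its next sample immediately, so only the slowest chain's total matters."""
    return max(sum(chain) for chain in work)


if __name__ == "__main__":
    work = draw_work(n_chains=128, n_samples=1_000, p=0.2)
    sync, fsm = lockstep_steps(work), fsm_steps(work)
    print(f"lockstep: {sync} steps, FSM: {fsm} steps, ratio: {sync / fsm:.2f}")
```

Because each chain's total work is at most the sum of the per-iteration maxima, the lockstep count is never smaller than the FSM count in this model. The model deliberately ignores the cost of an FSM step itself: under vmap, a lax.cond or lax.switch with a batched predicate is lowered to a select that evaluates every branch, so real gains depend on per-step overheads this sketch leaves out.
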
Problem

Avoid synchronization overhead in vectorized MCMC algorithms
Design single-chain MCMC using finite state machines
Achieve speed-ups in MCMC algorithms like HMC-NUTS
Innovation

Multi-chain MCMC vectorized with automatic tools such as JAX's vmap
Avoiding synchronization via finite state machines
Speed-ups of up to an order of magnitude demonstrated across several MCMC algorithms
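
As an illustration of writing a single-chain sampler as an FSM, below is a minimal pure-Python sketch of Elliptical Slice Sampling split into two states: INIT (draw the auxiliary ellipse and the slice threshold) and PROPOSE (evaluate one angle, then either accept or shrink the angle bracket). It assumes a scalar N(0, 1) prior and a user-supplied log-likelihood; ChainState, fsm_step, and run_fsm are hypothetical names, and this is not the paper's implementation. The driver calls fsm_step once per chain per outer step, mimicking a vmapped step function, so a chain stuck in a long shrinkage loop does not stall the others.

```python
import math
import random
from dataclasses import dataclass, field

INIT, PROPOSE = 0, 1  # the two FSM states in this sketch


@dataclass
class ChainState:
    f: float                  # current position; prior assumed N(0, 1)
    mode: int = INIT          # which FSM state the chain is in
    nu: float = 0.0           # auxiliary prior draw defining the ellipse
    log_y: float = 0.0        # log slice threshold
    theta: float = 0.0        # current angle on the ellipse
    lo: float = 0.0           # shrinking angle bracket [lo, hi]
    hi: float = 0.0
    samples: list = field(default_factory=list)


def fsm_step(s: ChainState, log_lik, rng: random.Random) -> None:
    """Advance one chain by exactly one FSM transition (one unit of work)."""
    if s.mode == INIT:
        s.nu = rng.gauss(0.0, 1.0)
        s.log_y = log_lik(s.f) + math.log(1.0 - rng.random())  # 1 - U lies in (0, 1]
        s.theta = rng.uniform(0.0, 2.0 * math.pi)
        s.lo, s.hi = s.theta - 2.0 * math.pi, s.theta
        s.mode = PROPOSE
        return
    # PROPOSE: evaluate one point on the ellipse, then accept or shrink.
    f_new = s.f * math.cos(s.theta) + s.nu * math.sin(s.theta)
    if log_lik(f_new) > s.log_y:
        s.f = f_new
        s.samples.append(f_new)
        s.mode = INIT  # the next transition starts a fresh sample
    else:
        if s.theta < 0.0:
            s.lo = s.theta
        else:
            s.hi = s.theta
        s.theta = rng.uniform(s.lo, s.hi)


def run_fsm(n_chains: int, n_samples: int, log_lik, seed: int = 0):
    """Step every chain once per outer iteration until all have n_samples draws."""
    rng = random.Random(seed)
    chains = [ChainState(f=rng.gauss(0.0, 1.0)) for _ in range(n_chains)]
    steps = 0
    while any(len(c.samples) < n_samples for c in chains):
        for c in chains:  # a vmapped step function would do this for all chains at once
            fsm_step(c, log_lik, rng)
        steps += 1
    return chains, steps


if __name__ == "__main__":
    # Likelihood N(1 | f, 0.5^2) with the N(0, 1) prior gives posterior N(0.8, 0.2).
    chains, steps = run_fsm(8, 2_000, lambda f: -2.0 * (f - 1.0) ** 2)
    draws = [x for c in chains for x in c.samples[200:2_000]]
    print(f"{steps} batched steps, posterior mean estimate {sum(draws) / len(draws):.3f}")
```

In a JAX version of this sketch, the chain state would become a pytree of arrays and the Python branch on mode a lax.switch or lax.cond, with vmap applied to the single-transition step function rather than to a whole-sample loop.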