🤖 AI Summary
This work addresses the performance bottleneck in auto-vectorized multi-chain MCMC (e.g., JAX vmap) caused by inter-chain synchronization: at each iteration, every chain must wait for the slowest chain to finish drawing its sample. We propose designing single-chain MCMC algorithms as finite-state machines (FSMs) so that vectorized chains no longer wait on one another. Using a simplified model, we derive an exact theoretical form for the obtainable speed-up and use it to make principled recommendations for algorithm design. We implement several popular samplers as FSMs, including Elliptical Slice Sampling, HMC-NUTS, and Delayed Rejection, and observe speed-ups of up to an order of magnitude in experiments.
📝 Abstract
With the advent of automatic vectorization tools (e.g., JAX's $\texttt{vmap}$), writing multi-chain MCMC algorithms is often now as simple as invoking those tools on single-chain code. Whilst convenient, for various MCMC algorithms this results in a synchronization problem -- loosely speaking, at each iteration all chains running in parallel must wait until the last chain has finished drawing its sample. In this work, we show how to design single-chain MCMC algorithms in a way that avoids synchronization overheads when vectorizing with tools like $\texttt{vmap}$ by using the framework of finite state machines (FSMs). Using a simplified model, we derive an exact theoretical form of the obtainable speed-ups using our approach, and use it to make principled recommendations for optimal algorithm design. We implement several popular MCMC algorithms as FSMs, including Elliptical Slice Sampling, HMC-NUTS, and Delayed Rejection, demonstrating speed-ups of up to an order of magnitude in experiments.
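The synchronization problem can be illustrated with a toy cost model. The sketch below is not the paper's model or code; it assumes, purely for illustration, that each sample needs a random number of inner steps (for example, repeated shrinking or rejection steps) drawn from a geometric distribution. The names `simulate_step_counts`, `synchronized_cost`, `fsm_cost`, and the parameter `p_accept` are hypothetical. Under lock-step vectorization, each sample costs as many vectorized steps as the slowest chain needs for that sample. If each chain instead advances independently, one step per vectorized call, the total cost is set by the chain with the largest overall step count.

```python
import random


def simulate_step_counts(num_chains, num_samples, p_accept=0.3, seed=0):
    """Hypothetical workload: steps[c][i] is the number of inner steps
    chain c needs to produce its i-th sample (geometric, assumed)."""
    rng = random.Random(seed)

    def geometric():
        n = 1
        # Keep taking inner steps until one is accepted.
        while rng.random() > p_accept:
            n += 1
        return n

    return [[geometric() for _ in range(num_samples)]
            for _ in range(num_chains)]


def synchronized_cost(steps):
    """Lock-step vectorization: for every sample index, all chains wait
    for the slowest chain, so the cost is a sum of per-sample maxima."""
    num_samples = len(steps[0])
    return sum(max(chain[i] for chain in steps) for i in range(num_samples))


def fsm_cost(steps):
    """Independent progress (FSM-style): each vectorized call advances
    every chain by one inner step, so the cost is the largest per-chain
    total."""
    return max(sum(chain) for chain in steps)


if __name__ == "__main__":
    steps = simulate_step_counts(num_chains=64, num_samples=200)
    sync = synchronized_cost(steps)
    fsm = fsm_cost(steps)
    print(f"synchronized: {sync} vectorized steps")
    print(f"independent:  {fsm} vectorized steps")
    print(f"ratio:        {sync / fsm:.2f}")
```

In this toy model the independent cost can never exceed the synchronized cost, since a maximum of sums is at most the sum of maxima. The gap grows with the number of chains and with the variability of the per-sample step counts. The model ignores the per-step overhead an FSM implementation may add, so the printed ratio is only indicative.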