🤖 AI Summary
Neural Processes (NPs) struggle to balance accuracy and efficiency in large-scale spatiotemporal modeling—especially in geoscience, climate science, epidemiology, and robotics, where high dimensionality, multiscale structure, long-range dependencies, and translation-invariant patterns are prevalent. To address this, we propose the Biased Scan Attention Transformer NP (BSA-TNP). Our method introduces three key innovations: (1) Kernel Regression Blocks to strengthen local inductive bias; (2) group-invariant attention biases to explicitly encode translation invariance; and (3) memory-efficient Biased Scan Attention (BSA), enabling linear-complexity modeling of long-range dependencies. On benchmarks with up to one million test points and one hundred thousand context points, BSA-TNP runs inference in under one minute on a single 24GB GPU while matching or exceeding the accuracy of existing NP and Transformer baselines. It also supports joint multi-resolution learning and explicit spatiotemporal modeling.
📝 Abstract
Neural Processes (NPs) are a rapidly evolving class of models designed to directly model the posterior predictive distribution of stochastic processes. While early architectures were developed primarily as a scalable alternative to Gaussian Processes (GPs), modern NPs tackle far more complex and data-hungry applications spanning geology, epidemiology, climate, and robotics. These applications have placed increasing pressure on model scalability, with many architectures compromising accuracy for it. In this paper, we demonstrate that this tradeoff is often unnecessary, particularly when modeling fully or partially translation invariant processes. We propose a versatile new architecture, the Biased Scan Attention Transformer Neural Process (BSA-TNP), which introduces Kernel Regression Blocks (KRBlocks), group-invariant attention biases, and memory-efficient Biased Scan Attention (BSA). BSA-TNP is able to: (1) match or exceed the accuracy of the best models while often training in a fraction of the time, (2) exhibit translation invariance, enabling learning at multiple resolutions simultaneously, (3) transparently model processes that evolve in both space and time, (4) support high dimensional fixed effects, and (5) scale gracefully -- running inference with over 1M test points with 100K context points in under a minute on a single 24GB GPU.
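To make the "group-invariant attention bias" idea concrete, here is a minimal sketch of an attention bias computed only from pairwise coordinate offsets, so that shifting every context and target point by the same amount leaves the bias (and hence the attention pattern) unchanged. This is an illustrative toy in NumPy, not the paper's BSA implementation; the RBF-style bias, function names, and shapes are assumptions for exposition.

```python
import numpy as np

def rel_position_bias(ctx_xy, tgt_xy, lengthscale=1.0):
    """Additive attention bias from pairwise offsets only.

    Because the bias depends solely on (target - context) differences,
    it is invariant to any common translation of the inputs.
    Illustrative sketch only -- not the paper's BSA kernel.
    """
    diff = tgt_xy[:, None, :] - ctx_xy[None, :, :]   # (n_tgt, n_ctx, d)
    sq_dist = np.sum(diff ** 2, axis=-1)             # (n_tgt, n_ctx)
    return -sq_dist / (2.0 * lengthscale ** 2)       # added to logits

def biased_attention(q, k, v, bias):
    """Softmax attention with an additive positional bias on the logits."""
    logits = q @ k.T / np.sqrt(q.shape[-1]) + bias
    logits -= logits.max(axis=-1, keepdims=True)     # numerical stability
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

# Translation invariance check: shift all coordinates by the same offset.
rng = np.random.default_rng(0)
ctx, tgt = rng.normal(size=(5, 2)), rng.normal(size=(3, 2))
shift = np.array([10.0, -4.0])
assert np.allclose(rel_position_bias(ctx, tgt),
                   rel_position_bias(ctx + shift, tgt + shift))
```

The memory-efficient BSA variant in the paper additionally evaluates this biased attention with a scan so that cost grows linearly rather than quadratically; the sketch above shows only the invariance property of the bias itself.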