It Just Takes Two: Scaling Amortized Inference to Large Sets

📅 2026-05-08
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the high compute and memory cost of amortized inference when the conditioning variable is a large set of observations that must be processed jointly. The authors propose a decoupled, two-stage strategy: first, a mean-pool Deep Set encoder is trained on sets of at most two elements to learn set representations; then, a neural posterior estimation head is fine-tuned on pre-aggregated embeddings, which makes training cost essentially independent of the deployment set size N. The resulting encoder generalizes to sets of thousands of elements. Across scalar, image, multi-view 3D, molecular, and high-dimensional conditional generation benchmarks, the method matches or outperforms standard baselines at a fraction of the compute.
📝 Abstract
Neural posterior estimation has emerged as a powerful tool for amortized inference, with growing adoption across scientific and applied domains. In many of these applications, the conditioning variable is a set of observations whose elements depend not only on the target but also on unknown factors shared across the set. Optimal inference therefore requires treating the set jointly, which in turn requires training the estimator at the deployment set size, a regime where memory and compute quickly become prohibitive. We introduce a simple, theoretically grounded strategy that decouples representation learning from posterior modeling. Our method trains a mean-pool Deep Set on sets of size at most two, producing an encoder that generalizes to arbitrary set sizes. The inference head is then fine-tuned on pre-aggregated embeddings, making training cost essentially independent of the deployment set size N. Across scalar, image, multi-view 3D, molecular, and high-dimensional conditional generation benchmarks with N in the thousands, our approach matches or outperforms standard baselines at a fraction of the compute.
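As a rough illustration of why a mean-pool Deep Set encoder can be applied at any set size, the sketch below embeds a two-element set and a 2000-element set with the same per-element encoder and averages the results. It is not the authors' implementation: the encoder `phi`, its random weights, the dimensions `D_IN` and `D_OUT`, and the helper `make_set` are all hypothetical stand-ins, and no training is performed.

```python
import math
import random

random.seed(0)
D_IN, D_OUT = 3, 4  # illustrative dimensions, not taken from the paper

# Hypothetical per-element encoder phi: a random linear map followed by tanh.
W = [[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(D_OUT)]

def phi(x):
    """Embed one element x (a list of D_IN floats) into D_OUT floats."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W]

def mean_pool(elements):
    """Mean-pool Deep Set embedding: the average of per-element embeddings."""
    total = [0.0] * D_OUT
    for x in elements:
        total = [t + e for t, e in zip(total, phi(x))]
    return [t / len(elements) for t in total]

def make_set(n):
    """Draw a toy observation set of n elements (assumed Gaussian data)."""
    return [[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(n)]

pair = make_set(2)      # training-size set (at most two elements)
large = make_set(2000)  # deployment-size set

z_pair = mean_pool(pair)
z_large = mean_pool(large)

# Reordering the set leaves the embedding unchanged up to rounding.
z_large_reversed = mean_pool(large[::-1])
```

Both embeddings have the same length regardless of set size, so an embedding such as `z_large` can be computed once and stored; a downstream inference head then consumes a fixed-size vector instead of the raw set, which is the sense in which aggregation can happen ahead of head training.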
Problem

Research questions and friction points this paper is trying to address.

amortized inference
neural posterior estimation
set-structured data
scalability
joint inference
Innovation

Methods, ideas, or system contributions that make the work stand out.

amortized inference
Deep Sets
set representation learning
neural posterior estimation
scalable inference