Iterative Amortized Inference: Unifying In-Context Learning and Learned Optimizers

📅 2025-10-13
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
Existing amortized learning approaches—including meta-learning, in-context learning, prompt tuning, and learned optimizers—face two key bottlenecks in task adaptation: (1) heterogeneous modeling of task-specific information across paradigms, and (2) poor scalability to long contexts and large datasets at inference time. To address these, the paper proposes an **iterative amortized inference framework**, presented as the first to unify the in-context learning and learned optimizer paradigms under a single principled formulation. It introduces a taxonomy of amortization regimes—parametric, implicit, and explicit—and incorporates a mini-batch-based stochastic iterative update mechanism, enabling scalable adaptation to long sequences and large datasets. The framework improves flexibility and scalability in multi-task generalization and establishes a theoretical and practical foundation for general-purpose task adaptation, grounded in the joint design of optimization and inference.

📝 Abstract
Modern learning systems increasingly rely on amortized learning - the idea of reusing computation or inductive biases shared across tasks to enable rapid generalization to novel problems. This principle spans a range of approaches, including meta-learning, in-context learning, prompt tuning, learned optimizers and more. While motivated by similar goals, these approaches differ in how they encode and leverage task-specific information, often provided as in-context examples. In this work, we propose a unified framework which describes how such methods differ primarily in the aspects of learning they amortize - such as initializations, learned updates, or predictive mappings - and how they incorporate task data at inference. We introduce a taxonomy that categorizes amortized models into parametric, implicit, and explicit regimes, based on whether task adaptation is externalized, internalized, or jointly modeled. Building on this view, we identify a key limitation in current approaches: most methods struggle to scale to large datasets because their capacity to process task data at inference (e.g., context length) is often limited. To address this, we propose iterative amortized inference, a class of models that refine solutions step-by-step over mini-batches, drawing inspiration from stochastic optimization. Our formulation bridges optimization-based meta-learning with forward-pass amortization in models like LLMs, offering a scalable and extensible foundation for general-purpose task adaptation.
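The mini-batch refinement loop the abstract describes can be sketched in a toy form. Everything below is illustrative, not the paper's implementation: the `refine` step network is hypothetical, and a fixed weight matrix implementing plain gradient descent stands in for a trained update model; a real learned optimizer would train `step_net_w` across tasks.

```python
import numpy as np

rng = np.random.default_rng(0)

def refine(z, grad, step_net_w):
    # Learned update: a hypothetical linear "step network" maps the current
    # solution and its mini-batch gradient to the next solution.
    return z + step_net_w @ np.concatenate([z, grad])

# Toy linear-regression task: recover w_true from streamed mini-batches,
# never materializing the full dataset as one context.
d, n = 3, 512
w_true = rng.normal(size=d)
X = rng.normal(size=(n, d))
y = X @ w_true

# Stand-in for a trained update network: gradient descent with step size lr
# corresponds to step_net_w = [0 | -lr * I].
lr = 0.1
step_net_w = np.concatenate([np.zeros((d, d)), -lr * np.eye(d)], axis=1)

z = np.zeros(d)  # amortized initialization (here simply zeros)
for t in range(200):
    idx = rng.choice(n, size=32, replace=False)  # mini-batch of task data
    Xb, yb = X[idx], y[idx]
    grad = 2 * Xb.T @ (Xb @ z - yb) / len(idx)   # MSE gradient on the batch
    z = refine(z, grad, step_net_w)

print(np.allclose(z, w_true, atol=1e-2))  # → True
```

The point of the sketch is the interface: each step consumes only a mini-batch, so the cost per iteration is independent of dataset size, in contrast to in-context approaches whose capacity is bounded by context length.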
Problem

Research questions and friction points this paper is trying to address.

Unifying diverse amortized learning methods into a single framework
Addressing scalability limitations in processing large task datasets
Bridging optimization-based meta-learning with forward-pass amortization approaches
Innovation

Methods, ideas, or system contributions that make the work stand out.

Iterative amortized inference refines solutions step-by-step
Unifies optimization-based meta-learning with forward-pass amortization
Scales to large datasets via mini-batch processing