🤖 AI Summary
This work addresses the lack of theoretical understanding of how Transformers approximate the posterior predictive distribution (PPD) through in-context learning, focusing on Gaussian process regression tasks. Through a constructive proof, we show that Transformers can implement gradient descent within their attention mechanism to approximate both the mean and the variance of the PPD. They then map these estimates through nonlinear transformations to produce discretized (binned) predictive probabilities. We derive bounds on the approximation error in terms of attention depth and bin resolution. We further show that normalization and the choice of attention depth are key to extrapolating beyond the range of sample sizes seen during pretraining. Simulations corroborate these findings and offer insight into the expressivity of prior-data fitted networks (PFNs) that target the PPD.
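The core mechanism is gradient descent converging toward the GP posterior predictive mean and variance. It can be illustrated numerically without any transformer. The plain-Python sketch below rests on assumptions not taken from the paper: a zero-mean GP, an RBF kernel, a fixed noise variance, and the hypothetical helpers `gd_solve` and `gp_ppd`. The number of gradient-descent steps loosely stands in for attention depth. This is not the paper's transformer construction, only an illustration of the kind of algorithm that construction is said to implement.

```python
import math


def rbf(a, b, lengthscale=1.0):
    # Squared-exponential (RBF) kernel on scalar inputs; the kernel choice is an assumption.
    return math.exp(-0.5 * ((a - b) / lengthscale) ** 2)


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def matvec(A, v):
    return [dot(row, v) for row in A]


def exact_solve(A, b):
    # Reference solver: Gaussian elimination with partial pivoting.
    n = len(b)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, n):
            f = M[r][c] / M[c][c]
            for k in range(c, n + 1):
                M[r][k] -= f * M[c][k]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - dot(M[r][r + 1:n], x[r + 1:n])) / M[r][r]
    return x


def gd_solve(A, b, steps):
    # Gradient descent on f(x) = 0.5 x^T A x - b^T x, whose minimizer solves A x = b.
    # Hypothetical helper: 'steps' loosely plays the role of attention depth.
    # Step size 1 / (max absolute row sum) <= 1 / lambda_max (Gershgorin),
    # so the iteration converges for a symmetric positive definite A.
    lr = 1.0 / max(sum(abs(a) for a in row) for row in A)
    x = [0.0] * len(b)
    for _ in range(steps):
        residual = [bi - ai for bi, ai in zip(b, matvec(A, x))]  # equals -grad f(x)
        x = [xi + lr * ri for xi, ri in zip(x, residual)]
    return x


def gp_ppd(xs, ys, x_star, solve, noise=0.1, lengthscale=1.0):
    # Posterior predictive mean and variance of y* for zero-mean GP regression.
    n = len(xs)
    A = [[rbf(xs[i], xs[j], lengthscale) + (noise if i == j else 0.0) for j in range(n)]
         for i in range(n)]  # A = K + noise * I
    k_star = [rbf(x_star, xi, lengthscale) for xi in xs]
    mean = dot(k_star, solve(A, ys))  # k*^T A^{-1} y
    var = rbf(x_star, x_star, lengthscale) + noise - dot(k_star, solve(A, k_star))
    return mean, var


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [math.sin(x) for x in xs]
exact = gp_ppd(xs, ys, 0.5, exact_solve)
for steps in (5, 50, 500):
    approx = gp_ppd(xs, ys, 0.5, lambda A, b: gd_solve(A, b, steps))
    print(steps, abs(approx[0] - exact[0]), abs(approx[1] - exact[1]))
```

Running it should show both the mean and variance errors shrinking as the number of steps grows. This loosely mirrors the depth-dependent error bounds described above. The step size is one simple choice that guarantees convergence for a positive definite system; the paper's construction may use a different parameterization.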
📝 Abstract
Prior-data fitted networks (PFNs) have recently emerged as a powerful approach for Bayesian prediction tasks, approximating the posterior predictive distribution (PPD) through in-context learning. Despite their strong empirical performance and ability to go beyond point predictions, theoretical understanding of the algorithmic capability of transformers to learn distributions in context is still lacking. Focusing on Gaussian process regression problems, we show by construction that transformers can implement a gradient descent algorithm targeting the posterior predictive mean and variance, followed by nonlinear mappings that yield binned probabilities of the PPD. We study the error bounds of the approximated PPD in terms of attention depth and bin resolution. Based on these results, we further demonstrate the key role of normalization and the choice of attention depth in enabling transformers to extrapolate beyond the range of sample sizes seen during pretraining. We conduct simulations that corroborate our findings, providing insight into the expressivity of PFNs targeting PPDs and into how architectural choices may influence generalization capabilities.
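The final stage maps an estimated mean and variance to binned PPD probabilities. A minimal sketch rests on several assumptions: a Gaussian PPD, equal-width bins on a fixed range `[lo, hi]`, and tail mass folded into the two edge bins. `binned_ppd` is a hypothetical helper, not the paper's output layer. Comparing the bin-midpoint mean with the true mean gives a rough feel for how discretization error depends on bin resolution.

```python
import math


def normal_cdf(z):
    # Standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def binned_ppd(mean, var, lo=-5.0, hi=5.0, n_bins=10):
    # Map a Gaussian PPD N(mean, var) to probabilities over equal-width bins on [lo, hi].
    # Assumption: mass below lo / above hi is folded into the first / last bin.
    sd = math.sqrt(var)
    edges = [lo + (hi - lo) * i / n_bins for i in range(n_bins + 1)]
    cdf = [normal_cdf((e - mean) / sd) for e in edges]
    cdf[0], cdf[-1] = 0.0, 1.0  # fold the tails into the edge bins
    probs = [cdf[i + 1] - cdf[i] for i in range(n_bins)]
    mids = [(edges[i] + edges[i + 1]) / 2 for i in range(n_bins)]
    return mids, probs


def binned_mean(mids, probs):
    # Mean of the discretized distribution, placing each bin's mass at its midpoint.
    return sum(m * p for m, p in zip(mids, probs))


mu, var = 0.3, 0.64
for n_bins in (4, 40, 400):
    mids, probs = binned_ppd(mu, var, n_bins=n_bins)
    print(n_bins, round(sum(probs), 12), abs(binned_mean(mids, probs) - mu))
```

With 4 bins the midpoint mean is visibly off. Finer bins shrink the discretization error, which is consistent in spirit with the bin-resolution term in the error bounds. The paper's actual bin layout and bounds may differ.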