AI Summary
This work investigates how Transformers learn, in context, how important each past token is for predicting the next one. To this end, we model token importance as a mixture of transition distributions governed by a latent variable and show that the Transformer's attention mechanism can learn the corresponding mixture weights. Our key contribution is a formal connection, established for the first time, between the in-context learning ability of Transformers and mirror descent: an explicit three-layer construction that exactly implements one step of mirror descent and whose estimator is a first-order approximation of the Bayes-optimal predictor. Experiments show that Transformers trained from scratch closely match the theoretical predictions in their predictive distributions, attention patterns, and learned transition matrices, and that deeper models perform comparably to multi-step mirror descent.
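To make the latent-variable model concrete, here is a minimal sketch of sampling from a Mixture of Transition Distributions. It assumes (our simplification, not necessarily the paper's exact parameterization) a single transition matrix `P` shared by all lags and mixture weights `lam` over lags 1..K: at each step a latent lag `k` is drawn from `lam`, and the next token is drawn from the row of `P` indexed by the token `k` steps back. The function names `sample_mtd` and `mtd_next_token_probs` are hypothetical.

```python
import numpy as np

def sample_mtd(P, lam, T, rng):
    """Sample a length-T sequence from a Mixture of Transition Distributions.

    Assumption (for illustration): one row-stochastic matrix P shared across
    lags, and lam[k-1] = probability that lag k is the relevant past token.
    """
    V, K = P.shape[0], len(lam)
    x = list(rng.integers(V, size=K))       # K uniformly random warm-up tokens
    for _ in range(T - K):
        k = rng.choice(K, p=lam) + 1         # latent variable: which lag matters
        prev = x[-k]                          # the causally relevant past token
        x.append(rng.choice(V, p=P[prev]))   # next token from that token's row of P
    return np.array(x)

def mtd_next_token_probs(x, P, lam):
    """Marginal next-token distribution: sum_k lam[k-1] * P[x[-k], :]."""
    return sum(lam[k - 1] * P[x[-k]] for k in range(1, len(lam) + 1))

# Usage with hypothetical parameters: 4 tokens, 3 lags.
rng = np.random.default_rng(0)
V, K = 4, 3
P = rng.dirichlet(np.ones(V), size=V)        # random row-stochastic V x V matrix
lam = np.array([0.6, 0.3, 0.1])              # mixture weights over lags 1, 2, 3
x = sample_mtd(P, lam, T=200, rng=rng)
p = mtd_next_token_probs(x, P, lam)
print(x[:10], p)
```

In this setting the learner sees `x` but not `lam`; inferring `lam` from the context is the in-context learning problem described above.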
Abstract
Sequence modelling requires determining which past tokens in the context are causally relevant and how important each one is. This process is inherent to the attention layers of transformers, yet the mechanisms they learn to carry it out remain poorly understood. In this work, we formalize the task of estimating token importance as an in-context learning problem by introducing a framework based on Mixture of Transition Distributions, in which a latent variable determines how past tokens influence the next one. The distribution over this latent variable is parameterized by unobserved mixture weights that transformers must learn in-context. We demonstrate that transformers can implement Mirror Descent to learn these weights from the context. Specifically, we give an explicit construction of a three-layer transformer that exactly implements one step of Mirror Descent, and we prove that the resulting estimator is a first-order approximation of the Bayes-optimal predictor. Corroborating both our construction and its learnability via gradient descent, we empirically show that transformers trained from scratch learn solutions consistent with our theory: their predictive distributions, attention patterns, and learned transition matrix closely match the construction, while deeper models achieve performance comparable to multi-step Mirror Descent.
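As a rough illustration of what "learning the mixture weights with Mirror Descent" can mean, the sketch below runs mirror ascent on the average in-context log-likelihood of the mixture weights, using the negative-entropy mirror map (an exponentiated-gradient update), starting from uniform weights. The mirror map, step size `eta`, uniform initialization, shared transition matrix `P`, and the toy context are all our assumptions; this is not the paper's transformer construction, only a plain-NumPy analogue of the update it is said to implement. Function names are hypothetical.

```python
import numpy as np

def log_lik_grad(x, P, lam):
    """Gradient w.r.t. lam of the average in-context log-likelihood
    L(lam) = mean_t log( sum_k lam[k-1] * P[x[t-k], x[t]] )."""
    K = len(lam)
    # Q[i, k-1] = P[x[t-k], x[t]] for each position t = K .. len(x)-1
    Q = np.array([[P[x[t - k], x[t]] for k in range(1, K + 1)]
                  for t in range(K, len(x))])
    mix = Q @ lam                        # likelihood of x[t] under current weights
    return (Q / mix[:, None]).mean(axis=0)

def mirror_descent_step(x, P, lam, eta=1.0):
    """One mirror-ascent step on the simplex with the negative-entropy
    mirror map, i.e. a multiplicative (exponentiated-gradient) update."""
    w = lam * np.exp(eta * log_lik_grad(x, P, lam))
    return w / w.sum()                   # renormalize onto the simplex

def predict_next(x, P, lam):
    """MTD next-token distribution sum_k lam[k-1] * P[x[-k], :]."""
    return sum(lam[k - 1] * P[x[-k]] for k in range(1, len(lam) + 1))

# Hypothetical toy context: the token two steps back determines the next one.
V, K = 3, 2
P = np.full((V, V), 0.05)
for i in range(V):
    P[i, (i + 1) % V] = 0.9              # each row sums to 1
x = [0, 1]
for t in range(2, 40):
    x.append((x[t - 2] + 1) % V)

lam0 = np.ones(K) / K                    # uninformed start: uniform over lags
lam1 = mirror_descent_step(x, P, lam0)   # a single step
lam5 = lam0
for _ in range(5):                       # multi-step mirror descent
    lam5 = mirror_descent_step(x, P, lam5)

print(lam1, lam5)
print(predict_next(x, P, lam1))
```

On this context one step already shifts weight toward lag 2, and further steps shift it more, loosely mirroring the abstract's observation that deeper models track multi-step Mirror Descent. The multiplicative update keeps the weights positive and normalized without a projection step, which is why the negative-entropy mirror map is a common choice on the simplex.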