🤖 AI Summary
This work addresses in-context learning for classification tasks by proposing a novel attention-model paradigm grounded in functional gradient descent (GD). Methodologically, it introduces a network architecture in which each block combines a self-attention layer, a cross-attention layer, and skip connections to explicitly model multi-step functional GD inference, achieving for the first time exact, interpretable, multi-step in-context reasoning with categorical observations. Theoretically, it relaxes prior attention models' reliance on kernel-function or distributional assumptions, broadening the class of attention mechanisms to which the functional GD view applies. Empirically, the approach yields significant improvements in few-shot classification accuracy and conditional language-generation quality on synthetic data, few-shot image classification, and text-generation benchmarks. This work establishes a theoretical foundation and a practical architecture for in-context learning, advancing both interpretability and performance in data-scarce settings.
📝 Abstract
In-context learning based on attention models is examined for data with categorical outcomes, with inference in such models viewed from the perspective of functional gradient descent (GD). We develop a network composed of attention blocks, each employing a self-attention layer followed by a cross-attention layer, with associated skip connections. This model can exactly perform multi-step functional GD inference for in-context learning with categorical observations. We perform a theoretical analysis of this setup, generalizing many prior assumptions in this line of work, including the class of attention mechanisms for which it is appropriate. We demonstrate the framework empirically on synthetic data, image classification, and language generation.
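The block structure described in the abstract can be illustrated with a minimal NumPy sketch. This is not the paper's implementation: it assumes a plain dot-product kernel for the attention scores (the paper's analysis covers a broader class of attention mechanisms), and it treats one block as a single functional GD step on per-point logits under the cross-entropy loss, with the skip connection carrying the current function values forward. All names here are illustrative.

```python
import numpy as np

def softmax(z, axis=-1):
    # numerically stable softmax over class logits
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gd_attention_block(F_ctx, F_qry, X_ctx, X_qry, Y_ctx, lr=0.05):
    """One functional-GD step written as attention with skip connections.

    F_ctx: (n, C) current logits at the n context points
    F_qry: (m, C) current logits at the m query points
    X_ctx: (n, d) context inputs; X_qry: (m, d) query inputs
    Y_ctx: (n, C) one-hot context labels
    """
    # negative functional gradient of cross-entropy at each context point
    R = Y_ctx - softmax(F_ctx)                # (n, C)
    # dot-product attention scores (an assumed kernel choice)
    A_self = X_ctx @ X_ctx.T                  # (n, n): context attends to context
    A_cross = X_qry @ X_ctx.T                 # (m, n): queries attend to context
    # skip connections: add the GD update onto the current function values
    F_ctx_new = F_ctx + lr * A_self @ R
    F_qry_new = F_qry + lr * A_cross @ R
    return F_ctx_new, F_qry_new
```

Stacking several such blocks then corresponds to running several steps of functional GD, so the query logits after the last block give the in-context prediction.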