Supervised learning pays attention

📅 2025-12-10
📈 Citations: 0
Influential: 0
🤖 AI Summary
This paper addresses a core challenge in supervised learning: global models struggle to accommodate individual heterogeneity while remaining interpretable. It proposes a supervised attention mechanism that defines attention weights as data-driven sample-similarity measures, dynamically weighting training instances to construct a personalized local model for each test point. The method unifies Lasso regression, gradient-boosted trees, and residual-correction frameworks, and extends naturally to spatiotemporal data. Its key innovations are (i) the first formalization of supervised attention within classical supervised learning, enabling modeling of heterogeneous structure without pre-specified groups, and (ii) per-prediction feature importance together with identification of the most influential supporting training samples. Theoretically, under mixed data-generating mechanisms, its mean squared error strictly improves on that of the standard linear model. Empirical evaluations on synthetic and real-world datasets demonstrate significant accuracy gains while preserving model parsimony and global interpretability.

📝 Abstract
In-context learning with attention enables large neural networks to make context-specific predictions by selectively focusing on relevant examples. Here, we adapt this idea to supervised learning procedures such as lasso regression and gradient boosting, for tabular data. Our goals are to (1) flexibly fit personalized models for each prediction point and (2) retain model simplicity and interpretability. Our method fits a local model for each test observation by weighting the training data according to attention, a supervised similarity measure that emphasizes features and interactions that are predictive of the outcome. Attention weighting allows the method to adapt to heterogeneous data in a data-driven way, without requiring cluster or similarity pre-specification. Further, our approach is uniquely interpretable: for each test observation, we identify which features are most predictive and which training observations are most relevant. We then show how to use attention weighting for time series and spatial data, and we present a method for adapting pretrained tree-based models to distributional shift using attention-weighted residual corrections. Across real and simulated datasets, attention weighting improves predictive performance while preserving interpretability, and theory shows that attention-weighted linear models attain lower mean squared error than the standard linear model under mixture-of-models data-generating processes with known subgroup structure.
Problem

Research questions and friction points this paper is trying to address.

Adapts in-context learning to supervised methods for tabular data
Fits personalized models per prediction point while keeping interpretability
Handles heterogeneous data without pre-specifying clusters or similarities
Innovation

Methods, ideas, or system contributions that make the work stand out.

Attention weighting for personalized local models
Supervised similarity measure for data-driven adaptation
Interpretable feature and training observation identification
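The abstract also mentions adapting pretrained tree-based models to distributional shift using attention-weighted residual corrections. A hedged sketch of that idea: correct each prediction by an attention-weighted average of the pretrained model's residuals on a small labeled set from the shifted distribution. The plain Gaussian kernel and all helper names below are illustrative assumptions, not the paper's construction.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(1)

# Source data and a pretrained tree-based model.
n, p = 300, 3
f = lambda X: X[:, 0] + np.sin(X[:, 1])
X_src = rng.normal(size=(n, p))
y_src = f(X_src) + 0.1 * rng.normal(size=n)
model = GradientBoostingRegressor(random_state=0).fit(X_src, y_src)

# Shifted distribution: same mechanism plus a constant outcome offset,
# observed on a small labeled adaptation set.
X_adapt = rng.normal(size=(50, p))
y_adapt = f(X_adapt) + 2.0 + 0.1 * rng.normal(size=50)
resid = y_adapt - model.predict(X_adapt)  # pretrained model's residuals

def kernel_weights(x_new, X_ref, bandwidth=1.0):
    """Illustrative attention: normalized Gaussian kernel in feature space."""
    d2 = ((X_ref - x_new) ** 2).sum(axis=1)
    w = np.exp(-d2 / (2.0 * bandwidth**2))
    return w / w.sum()

def predict_corrected(x_new):
    """Pretrained prediction plus attention-weighted residual correction."""
    w = kernel_weights(x_new, X_adapt)
    return model.predict(x_new[None, :])[0] + w @ resid

x_new = rng.normal(size=p)
base = model.predict(x_new[None, :])[0]
corrected = predict_corrected(x_new)
# The correction pulls the frozen model's output toward the shifted
# outcome scale without retraining the trees.
```

The pretrained model stays frozen; only the lightweight residual correction adapts, which is what makes the scheme attractive when retraining on the shifted distribution is impractical.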