🤖 AI Summary
This work investigates how Transformers model hierarchical positional dependencies in tree-structured data. We propose a hierarchical filtering generative model that allows positional correlations to be tuned in a controlled way across multiple scales, and combine attention-map analysis, hierarchical probing, and encoder-only training to systematically characterize the underlying modeling mechanisms. We show that Transformer encoder layers progressively capture longer-range hierarchical dependencies with depth: shallow layers encode local adjacency relations, while deeper layers specialize in global tree topology, with each layer approximately reconstructing correlation patterns at a specific scale. Empirical results demonstrate that this architecture approaches the performance of exact Bayesian inference on trees for root-node classification and masked language modeling. Our findings establish a verifiable, scale-separated computational mechanism for interpretable AI, grounded in principled hierarchical representation learning.
📝 Abstract
Understanding the learning process and the embedded computation in transformers is becoming a central goal for the development of interpretable AI. In the present study, we introduce a hierarchical filtering procedure for generative models of sequences on trees, allowing us to hand-tune the range of positional correlations in the data. Leveraging this controlled setting, we provide evidence that vanilla encoder-only transformers can approximate the exact inference algorithm when trained on root classification and masked language modeling tasks, and study how this computation is discovered and implemented. We find that correlations at larger distances, corresponding to increasing layers of the hierarchy, are sequentially included by the network during training. Moreover, by comparing attention maps from models trained with varying degrees of filtering and by probing the different encoder levels, we find clear evidence of a reconstruction of correlations on successive length scales corresponding to the various levels of the hierarchy, which we relate to a plausible implementation of the exact inference algorithm within the same architecture.
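To make the setting concrete, here is a minimal sketch (not from the paper) of the two ingredients the abstract describes: a Markov generative model on a tree whose correlations can be cut off at a chosen level of the hierarchy, and the exact Bayesian root inference that the trained transformers are compared against. The binary branching, alphabet size `q`, depth, transition matrix `M`, and the specific filtering rule (replacing transitions above a cutoff level with uniform ones) are all illustrative assumptions, not the authors' exact construction.

```python
import numpy as np

rng = np.random.default_rng(0)
q, depth = 4, 4                          # assumed alphabet size and tree depth
M = rng.dirichlet(np.ones(q), size=q)    # assumed transition matrix: M[a, b] = P(child=b | parent=a)

def sample_leaves(root, level=0, k_filter=depth):
    """Sample the 2**depth leaf symbols of a binary tree below `root`.

    Levels >= k_filter use a uniform ("filtered") transition, which severs
    correlations beyond the corresponding length scale -- a hypothetical
    stand-in for the paper's hierarchical filtering procedure.
    """
    if level == depth:
        return [root]
    T = M if level < k_filter else np.full((q, q), 1.0 / q)
    left = rng.choice(q, p=T[root])
    right = rng.choice(q, p=T[root])
    return (sample_leaves(left, level + 1, k_filter)
            + sample_leaves(right, level + 1, k_filter))

def root_posterior(leaves):
    """Exact root inference via the upward belief-propagation pass on the tree."""
    msgs = [np.eye(q)[x] for x in leaves]            # leaf indicator messages
    while len(msgs) > 1:                             # merge sibling pairs level by level
        msgs = [(M @ msgs[i]) * (M @ msgs[i + 1])
                for i in range(0, len(msgs), 2)]
    return msgs[0] / msgs[0].sum()                   # normalize (uniform root prior)
```

Under this sketch, root classification means predicting `argmax(root_posterior(leaves))` from the leaf sequence alone, and lowering `k_filter` removes exactly the long-range correlations that, per the abstract, the deeper encoder layers are found to reconstruct during training.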