🤖 AI Summary
This work investigates the functional semantics encoded in the embedding distances that Message Passing Neural Networks (MPNNs) implicitly learn for specific downstream tasks. It moves beyond prior work, which characterizes MPNN distances as task-agnostic graph-structural metrics, toward task-specific functional subgraph distances.
Method: We propose a Wasserstein optimal transport framework built on the Weisfeiler-Leman Labeling Tree (WILT): the tree encodes local graph structure, while learned edge weights quantify how much each functional subgraph contributes to the distance between prediction targets. The framework runs in linear time while generalizing well-known graph kernel methods.
Contribution/Results: Experiments demonstrate that MPNN embedding distances are predominantly governed by a sparse set of functionally critical subgraphs. The WILT-based distance faithfully recovers these semantics, enhancing both interpretability and predictive performance across multiple benchmarks.
📝 Abstract
We investigate the distance function learned by message passing neural networks (MPNNs) in specific tasks, aiming to capture the functional distance between prediction targets that MPNNs implicitly learn. This contrasts with previous work, which links MPNN distances on arbitrary tasks to structural distances on graphs that ignore task-specific information. To address this gap, we distill the distance between MPNN embeddings into an interpretable graph distance. Our method uses optimal transport on the Weisfeiler-Leman Labeling Tree (WILT), where the edge weights reveal subgraphs that strongly influence the distance between embeddings. This approach generalizes two well-known graph kernels and can be computed in linear time. Through extensive experiments, we demonstrate that MPNNs define the relative position of embeddings by focusing on a small set of subgraphs that are known to be functionally important in the domain.
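To make the construction concrete, the sketch below illustrates the core idea under simplifying assumptions: Weisfeiler-Leman refinement assigns each node a label per iteration, the labels form a tree (each refined label's parent is its label from the previous iteration), and optimal transport on a tree metric collapses to a weighted L1 difference of the label counts carried by each tree edge. The function names, the uniform initial coloring, and the `weights` dictionary (standing in for the learned edge weights) are illustrative, not the paper's actual implementation.

```python
from collections import Counter

def wl_histories(graphs, num_iters):
    """Run Weisfeiler-Leman refinement on all graphs with a shared label table,
    so identical neighborhood patterns get identical labels across graphs.
    Returns, per graph, one Counter of node labels for each iteration."""
    table = {}
    histories = []
    for adj in graphs:
        labels = {v: 0 for v in adj}          # uniform initial colors
        hist = [Counter(labels.values())]
        for it in range(num_iters):
            new = {}
            for v in adj:
                # signature: own label plus sorted multiset of neighbor labels
                sig = (it, labels[v], tuple(sorted(labels[u] for u in adj[v])))
                new[v] = table.setdefault(sig, len(table) + 1)
            labels = new
            hist.append(Counter(labels.values()))
        histories.append(hist)
    return histories

def wilt_distance(adj_g, adj_h, weights=None, num_iters=2):
    """Wasserstein distance on the WL label tree: on a tree metric, optimal
    transport reduces to a weighted L1 difference of normalized label counts,
    one term per tree edge, scaled by that edge's (learned) weight."""
    weights = weights or {}
    hist_g, hist_h = wl_histories([adj_g, adj_h], num_iters)
    n_g, n_h = len(adj_g), len(adj_h)
    dist = 0.0
    for cnt_g, cnt_h in zip(hist_g, hist_h):
        for lab in set(cnt_g) | set(cnt_h):
            w = weights.get(lab, 1.0)         # hypothetical learned weight; 1.0 = unweighted
            dist += w * abs(cnt_g[lab] / n_g - cnt_h[lab] / n_h)
    return dist

# Example: a triangle versus a 3-node path (adjacency lists)
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
path = {0: [1], 1: [0, 2], 2: [1]}
print(wilt_distance(triangle, path))
```

Because each iteration touches every edge once, the total cost is linear in graph size per iteration, which matches the linear-time claim; a large learned weight on a label downweights or highlights the corresponding subgraph's contribution to the distance.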