WILTing Trees: Interpreting the Distance Between MPNN Embeddings

📅 2025-05-30
🤖 AI Summary
This work investigates what the distances between Message Passing Neural Network (MPNN) embeddings encode once the network is trained on a specific downstream task. Prior work relates MPNN distances to structural graph distances that ignore task-specific information; this work instead targets the functional distance between prediction targets that MPNNs implicitly learn. Method: The distance between MPNN embeddings is distilled into an interpretable graph distance, defined by optimal transport on the Weisfeiler Leman Labeling Tree (WILT). The WILT encodes local graph structure, and its learned edge weights show which subgraphs most strongly influence the distance between embeddings. The distance generalizes two well-known graph kernels and can be computed in linear time. Contribution/Results: Experiments indicate that MPNN embedding distances are governed mainly by a small set of subgraphs known to be functionally important in the domain, and the WILT edge weights expose these subgraphs.

📝 Abstract
We investigate the distance function learned by message passing neural networks (MPNNs) in specific tasks, aiming to capture the functional distance between prediction targets that MPNNs implicitly learn. This contrasts with previous work, which links MPNN distances on arbitrary tasks to structural distances on graphs that ignore task-specific information. To address this gap, we distill the distance between MPNN embeddings into an interpretable graph distance. Our method uses optimal transport on the Weisfeiler Leman Labeling Tree (WILT), where the edge weights reveal subgraphs that strongly influence the distance between embeddings. This approach generalizes two well-known graph kernels and can be computed in linear time. Through extensive experiments, we demonstrate that MPNNs define the relative position of embeddings by focusing on a small set of subgraphs that are known to be functionally important in the domain.
Problem

Research questions and friction points this paper is trying to address.

Interpreting the distance function that MPNNs learn on specific tasks
Distilling the distance between MPNN embeddings into an interpretable graph distance
Identifying the subgraphs that influence embedding distances via WILT edge weights
Innovation

Methods, ideas, or system contributions that make the work stand out.

Uses optimal transport on the Weisfeiler Leman Labeling Tree (WILT)
Distills the distance between MPNN embeddings into an interpretable graph distance
Computes the graph distance in linear time
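
To make the optimal-transport-on-a-tree idea concrete, here is a minimal Python sketch, not the authors' implementation. It relies on the standard closed form for the Wasserstein distance on a tree: the sum over edges of the edge weight times the absolute difference in probability mass below that edge. The sketch assumes that each graph places uniform mass on the final-iteration Weisfeiler Leman labels of its vertices, so the mass below a label equals the fraction of vertices that carry it. The function names (`wl_labels`, `wilt_distance`), the shared `table`, the size normalization, and the uniform default edge weight are all illustrative assumptions; in the paper the edge weights are learned to match the MPNN embedding distances.

```python
from collections import Counter


def wl_labels(adj, node_labels, iterations, table):
    """Return the WL label of every vertex at iterations 0..iterations.

    `table` maps a label signature to (label_id, parent_id). It is shared
    across graphs so identical signatures receive identical ids; the
    parent_id entries implicitly define the labeling tree (None = root).
    """
    def intern(signature, parent):
        if signature not in table:
            table[signature] = (len(table), parent)
        return table[signature][0]

    current = [intern(("init", lab), None) for lab in node_labels]
    history = [current]
    for _ in range(iterations):
        refined = []
        for v, neighbours in enumerate(adj):
            # New label = own label plus the multiset of neighbour labels.
            signature = (current[v], tuple(sorted(current[u] for u in neighbours)))
            refined.append(intern(signature, current[v]))
        current = refined
        history.append(current)
    return history


def wilt_distance(hist_g, hist_h, weights, default=1.0):
    """Tree Wasserstein distance between two graphs on the labeling tree.

    Each label stands for the edge to its parent. `weights` maps label ids
    to edge weights (hypothetical here; `default` is used when missing).
    """
    counts_g = Counter(label for level in hist_g for label in level)
    counts_h = Counter(label for level in hist_h for label in level)
    n_g, n_h = len(hist_g[0]), len(hist_h[0])
    return sum(
        weights.get(c, default) * abs(counts_g[c] / n_g - counts_h[c] / n_h)
        for c in counts_g.keys() | counts_h.keys()
    )


# Toy example: a 3-vertex path versus a triangle, all vertices labelled "C".
table = {}
path = wl_labels([[1], [0, 2], [1]], ["C", "C", "C"], 1, table)
triangle = wl_labels([[1, 2], [0, 2], [0, 1]], ["C", "C", "C"], 1, table)
uniform = wilt_distance(path, triangle, weights={})
```

Given the label histories, the distance is a single pass over the labels that occur in either graph, which illustrates why a tree-based transport distance can be evaluated cheaply. Setting the weight of a label to zero removes that subtree pattern from the distance, which is the sense in which large weights point to influential subgraphs.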