Residual Connections and Normalization Can Provably Prevent Oversmoothing in GNNs

📅 2024-06-05
🏛️ International Conference on Learning Representations
📈 Citations: 7
Influential: 0
📄 PDF
🤖 AI Summary
Oversmoothing in Graph Neural Networks (GNNs) degrades node representation discriminability by collapsing distinct features into indistinguishable embeddings. Method: Grounded in spectral graph theory, we formulate a linearized GNN model and rigorously prove that residual connections preserve the initial feature subspace, while batch normalization prevents embedding collapse. We further propose GraphNormv2, a learnable centering-based normalization layer that enhances feature distinguishability without distorting the underlying graph signal. Contribution/Results: Theoretically, node representations converge to the subspace spanned by the top-k principal eigenvectors of the message-passing operator. Empirically, GraphNormv2 significantly improves both expressivity and generalization of GNNs across multiple benchmark datasets, outperforming standard normalization schemes.
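The following is a minimal NumPy sketch (not the authors' code) illustrating the residual-connection claim above on a toy graph: plain linearized message passing collapses node features toward indistinguishable embeddings, while re-injecting the initial features at every layer keeps them apart. The graph, feature dimension, and the residual weight `alpha` are arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy undirected graph: two triangles joined by a bridge edge (6 nodes).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
n = 6
A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = A[j, i] = 1.0

# Symmetrically normalized message-passing operator with self-loops.
A_loop = A + np.eye(n)
d_inv_sqrt = 1.0 / np.sqrt(A_loop.sum(axis=1))
A_hat = d_inv_sqrt[:, None] * A_loop * d_inv_sqrt[None, :]

H0 = rng.standard_normal((n, 4))   # initial node features
alpha = 0.2                        # weight on the initial features (residual term)

def feature_spread(H):
    """Proxy for distinguishability: Frobenius norm of the node-centered features."""
    return np.linalg.norm(H - H.mean(axis=0), ord="fro")

H_plain, H_res = H0.copy(), H0.copy()
for _ in range(50):                # 50 rounds of linearized propagation
    H_plain = A_hat @ H_plain                             # plain layer, no residual
    H_res = alpha * H0 + (1 - alpha) * (A_hat @ H_res)    # layer with initial-feature residual

print("spread, plain   :", feature_spread(H_plain))   # ~0: embeddings collapse
print("spread, residual:", feature_spread(H_res))     # stays bounded away from 0
```

The residual update converges to a fixed point determined by the initial features, which is one way to see the summary's statement that the initial feature subspace is preserved.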

📝 Abstract
Residual connections and normalization layers have become standard design choices for graph neural networks (GNNs), and were proposed as solutions to mitigate the oversmoothing problem in GNNs. However, how exactly these methods help alleviate the oversmoothing problem from a theoretical perspective is not well understood. In this work, we provide a formal and precise characterization of (linearized) GNNs with residual connections and normalization layers. We establish that (a) for residual connections, the incorporation of the initial features at each layer can prevent the signal from becoming too smooth, and determines the subspace of possible node representations; (b) batch normalization prevents a complete collapse of the output embedding space to a one-dimensional subspace through the individual rescaling of each column of the feature matrix. This results in the convergence of node representations to the top-$k$ eigenspace of the message-passing operator; (c) moreover, we show that the centering step of a normalization layer -- which can be understood as a projection -- alters the graph signal in message-passing in such a way that relevant information can become harder to extract. We therefore introduce a novel, principled normalization layer called GraphNormv2 in which the centering step is learned such that it does not distort the original graph signal in an undesirable way. Experimental results confirm the effectiveness of our method.
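Below is a hedged PyTorch sketch of a normalization layer with a learnable centering step, in the spirit of the GraphNormv2 idea described in the abstract. This is not the authors' implementation: the class name `LearnableCenterNorm` and the parameterization (a per-feature weight `center_scale` interpolating between no centering and full mean subtraction) are assumptions made purely for illustration.

```python
import torch
import torch.nn as nn


class LearnableCenterNorm(nn.Module):
    """Normalization with a learned centering step (illustrative sketch only)."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        # How much of the per-feature mean to subtract:
        # 0 = keep the component that plain centering would project out, 1 = full centering.
        self.center_scale = nn.Parameter(torch.ones(num_features))
        self.gamma = nn.Parameter(torch.ones(num_features))   # affine scale
        self.beta = nn.Parameter(torch.zeros(num_features))   # affine shift
        self.eps = eps

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [num_nodes, num_features] node embeddings of one graph / batch.
        mean = h.mean(dim=0, keepdim=True)
        centered = h - self.center_scale * mean            # learned centering
        var = centered.var(dim=0, unbiased=False, keepdim=True)
        return self.gamma * centered / torch.sqrt(var + self.eps) + self.beta


# Usage: place the layer between message-passing steps of a GNN.
h = torch.randn(6, 16)      # 6 nodes, 16-dimensional embeddings
norm = LearnableCenterNorm(16)
print(norm(h).shape)        # torch.Size([6, 16])
```

The design intent mirrors the abstract's point (c): because full centering acts as a projection that can discard useful components of the graph signal, letting the model learn how much to center offers a way to avoid that distortion.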
Problem

Research questions and friction points this paper is trying to address.

How residual connections prevent oversmoothing in GNNs
How normalization layers affect GNN output embedding space
Designing a principled normalization layer to preserve graph signals
Innovation

Methods, ideas, or system contributions that make the work stand out.

Residual connections prevent oversmoothing in GNNs
Batch normalization avoids output space collapse
GraphNormv2 learns non-distorting centering step
Michael Scholkemper
RWTH Aachen University
Machine Learning on Graphs, Role Extraction
Xinyi Wu
Institute for Data, Systems, and Society, Massachusetts Institute of Technology
A. Jadbabaie
Institute for Data, Systems, and Society, Massachusetts Institute of Technology
Michael T. Schaub
RWTH Aachen University
Networks, Applied Dynamical Systems, Neuroscience, Data Science, Graph Signal Processing