🤖 AI Summary
Over-smoothing in Graph Neural Networks (GNNs) degrades node representation discriminability by collapsing distinct features into indistinguishable embeddings. Method: Grounded in spectral graph theory, we formulate a linearized GNN model and rigorously prove that residual connections preserve the initial feature subspace, while batch normalization prevents embedding collapse. We further propose GraphNormv2, a learnable centering-based normalization layer that enhances feature distinguishability without distorting signal fidelity. Contribution/Results: Theoretically, node representations converge to the subspace spanned by the top-k principal eigenvectors of the message-passing operator. Empirically, GraphNormv2 significantly improves both expressivity and generalization of GNNs across multiple benchmark datasets, outperforming standard normalization schemes.
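The collapse-vs-preservation claim above can be illustrated numerically. The following is a minimal sketch (not the paper's construction) of a linearized message-passing model: plain propagation with a symmetrically normalized operator collapses all feature columns onto the dominant eigenvector, while a residual variant that re-injects the initial features keeps the embedding full-rank. The graph, the residual weight `alpha`, and the rank tolerance are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, steps = 30, 4, 100

# Random undirected graph with self-loops, symmetrically normalized:
# A_hat = D^{-1/2} (A + I) D^{-1/2}  (assumed message-passing operator)
A = (rng.random((n, n)) < 0.3).astype(float)
A = np.triu(A, 1)
A = A + A.T + np.eye(n)
d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
A_hat = A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

H0 = rng.standard_normal((n, d))  # initial node features

def effective_rank(H, rel_tol=1e-3):
    """Number of singular values above rel_tol times the largest one."""
    s = np.linalg.svd(H, compute_uv=False)
    return int((s > rel_tol * s[0]).sum())

# Plain linearized message passing: H_{k+1} = A_hat H_k.
# Repeated smoothing drives every column toward the dominant eigenvector.
H_plain = H0.copy()
for _ in range(steps):
    H_plain = A_hat @ H_plain

# Residual variant: H_{k+1} = alpha * A_hat H_k + (1 - alpha) * H0.
# Re-injecting the initial features at each layer preserves their subspace.
alpha = 0.5
H_res = H0.copy()
for _ in range(steps):
    H_res = alpha * (A_hat @ H_res) + (1 - alpha) * H0

print("plain effective rank:", effective_rank(H_plain))     # collapses to 1
print("residual effective rank:", effective_rank(H_res))    # stays at d
```

The residual iteration converges to a fixed point $(1-\alpha)(I - \alpha \hat{A})^{-1} H_0$, which is generically full-rank, consistent with the claim that residual connections determine the subspace of possible node representations.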
📝 Abstract
Residual connections and normalization layers have become standard design choices for graph neural networks (GNNs), and were proposed as solutions to mitigate the oversmoothing problem in GNNs. However, how exactly these methods help alleviate oversmoothing from a theoretical perspective is not well understood. In this work, we provide a formal and precise characterization of (linearized) GNNs with residual connections and normalization layers. We establish that (a) for residual connections, the incorporation of the initial features at each layer prevents the signal from becoming too smooth and determines the subspace of possible node representations; (b) batch normalization prevents a complete collapse of the output embedding space to a one-dimensional subspace through the individual rescaling of each column of the feature matrix. This results in the convergence of node representations to the top-$k$ eigenspace of the message-passing operator; (c) moreover, the centering step of a normalization layer -- which can be understood as a projection -- alters the graph signal in message-passing in such a way that relevant information can become harder to extract. We therefore introduce a novel, principled normalization layer called GraphNormv2, in which the centering step is learned so that it does not distort the original graph signal in an undesirable way. Experimental results confirm the effectiveness of our method.
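To make the "learned centering" idea concrete, here is a hypothetical sketch of a normalization layer in the spirit of GraphNormv2: instead of subtracting the full per-column mean (a hard projection that can remove useful graph signal), it subtracts only a learned fraction of it before rescaling. The parameter names (`alpha`, `gamma`, `beta`) and the exact parameterization are assumptions for illustration; the paper's actual layer may differ.

```python
import numpy as np

class LearnableCenterNorm:
    """Sketch of a normalization layer with a learnable centering step.

    For a node-feature matrix H (nodes x features), a standard batch-norm
    centering subtracts the full column mean. Here, a learnable per-feature
    weight alpha_j controls how much of the mean is removed, so the layer
    can avoid projecting out informative components of the graph signal.
    """

    def __init__(self, d, eps=1e-5):
        self.alpha = np.full(d, 0.5)  # learnable centering weights (illustrative init)
        self.gamma = np.ones(d)       # learnable scale
        self.beta = np.zeros(d)       # learnable shift
        self.eps = eps

    def __call__(self, H):
        mu = H.mean(axis=0)                     # per-column mean over nodes
        Hc = H - self.alpha * mu                # soft (learned) centering
        var = ((H - mu) ** 2).mean(axis=0)      # per-column variance
        return self.gamma * Hc / np.sqrt(var + self.eps) + self.beta
```

With `alpha = 1` the layer recovers standard batch-norm centering (zero column means); with `alpha = 0` it only rescales, leaving the mean component of the signal intact. Training can interpolate between these regimes per feature.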