🤖 AI Summary
This work addresses the lack of structural priors and efficient representations in neural network gradient modeling. We propose GradMetaNet, the first gradient-specific architecture that explicitly encodes neuron permutation equivariance, sets of gradients taken at multiple data points, and curvature information. Methodologically, it integrates equivariant neural modules, set encoders, and a rank-1 gradient decomposition into a hierarchical framework for encoding gradients, and we theoretically establish its universal approximation capability for natural gradient-based functions. Empirical evaluation on MLPs and Transformers demonstrates that GradMetaNet significantly outperforms existing approaches in learned optimization, implicit neural representation editing, and loss landscape curvature estimation, making it the first method to achieve structured, interpretable, and high-fidelity modeling of gradient functions.
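The rank-1 decomposition mentioned above can be seen concretely in the backward pass of a linear layer: the per-sample weight gradient is an outer product of the upstream gradient and the layer input, so storing the two factors is cheaper than storing the full matrix. The sketch below is illustrative (names and shapes are assumptions, not the paper's API):

```python
import numpy as np

# Hedged sketch: for a linear layer y = W @ x, the per-sample gradient
# dL/dW = outer(g, x) with g = dL/dy is rank-1 by construction, so the
# factor pair (g, x) represents it with d_out + d_in numbers instead of
# d_out * d_in. All names here are illustrative.
rng = np.random.default_rng(0)
d_in, d_out = 4, 3
x = rng.standard_normal(d_in)    # input to the layer
g = rng.standard_normal(d_out)   # upstream gradient dL/dy

grad_W = np.outer(g, x)          # full per-sample gradient

# rank-1 by construction, and the factored form reconstructs it exactly
assert np.linalg.matrix_rank(grad_W) == 1
assert np.allclose(grad_W, g[:, None] * x[None, :])
```

Note that only per-sample gradients are rank-1; a gradient averaged over a batch is a sum of such outer products and generally has higher rank, which is why the decomposition is applied per data point.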
📄 Abstract
Gradients of neural networks encode valuable information for optimization, editing, and analysis of models. Therefore, practitioners often treat gradients as inputs to task-specific algorithms, e.g., for pruning or optimization. Recent works explore learning algorithms that operate directly on gradients, but they use architectures that are not specifically designed for gradient processing, limiting their applicability. In this paper, we present a principled approach for designing architectures that process gradients. Our approach is guided by three principles: (1) equivariant design that preserves neuron permutation symmetries, (2) processing sets of gradients across multiple data points to capture curvature information, and (3) efficient gradient representation through rank-1 decomposition. Based on these principles, we introduce GradMetaNet, a novel architecture for learning on gradients, constructed from simple equivariant blocks. We prove universality results for GradMetaNet, and show that previous approaches cannot approximate natural gradient-based functions that GradMetaNet can. We then demonstrate GradMetaNet's effectiveness on a diverse set of gradient-based tasks on MLPs and transformers, such as learned optimization, INR (implicit neural representation) editing, and estimating loss landscape curvature.
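Principle (1) can be illustrated with a DeepSets-style layer, the kind of simple permutation-equivariant block the abstract alludes to (this is a generic sketch, not the paper's exact layer). Rows of the input index neurons; because the block only mixes each row with the mean over rows, permuting the rows permutes the output in the same way:

```python
import numpy as np

# Hedged sketch of a permutation-equivariant block over neuron axes
# (DeepSets-style; illustrative, not GradMetaNet's exact layer).
rng = np.random.default_rng(1)
n, d = 5, 3                      # n neurons, d features per neuron
X = rng.standard_normal((n, d))
A = rng.standard_normal((d, d))  # per-row weight
B = rng.standard_normal((d, d))  # pooled-context weight

def equivariant_block(X):
    # each row is updated from itself and the mean over all rows,
    # so the map commutes with any permutation of the rows
    return X @ A + X.mean(axis=0, keepdims=True) @ B

P = np.eye(n)[rng.permutation(n)]   # random permutation matrix
# equivariance check: permute-then-apply equals apply-then-permute
assert np.allclose(equivariant_block(P @ X), P @ equivariant_block(X))
```

Stacking such blocks keeps the whole network equivariant, which is what lets the architecture respect the neuron permutation symmetries of the gradients it processes.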
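Principle (2) rests on the fact that a single gradient carries no second-order information, while a set of gradients does. A minimal illustration (under the assumption of a quadratic loss, with all symbols invented for the example): the difference of gradients at two nearby parameter points recovers a Hessian-vector product.

```python
import numpy as np

# Hedged illustration: for a quadratic loss L(w) = 0.5 * w.T @ A @ w,
# gradients at two points reveal curvature that one gradient cannot.
# A, w, v, eps are illustrative assumptions, not quantities from the paper.
rng = np.random.default_rng(2)
d = 4
M = rng.standard_normal((d, d))
A = M @ M.T                      # symmetric PSD Hessian of the quadratic
grad = lambda w: A @ w           # gradient of L at w

w = rng.standard_normal(d)
v = rng.standard_normal(d)
eps = 1e-3
# finite difference of two gradients = Hessian-vector product
hvp = (grad(w + eps * v) - grad(w)) / eps
assert np.allclose(hvp, A @ v)   # exact for a quadratic loss
```

For general losses the identity holds only to first order in `eps`, but the point stands: curvature is a property of how gradients vary, so an architecture that sees a set of gradients can estimate it while a single-gradient architecture cannot.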