🤖 AI Summary
This paper investigates exact learning of arbitrary permutations on $k$-dimensional nonzero binary inputs using two-layer fully connected neural networks in the infinite-width Neural Tangent Kernel (NTK) limit. Rather than training on all $2^k - 1$ possible inputs, we show that the $k$ standard basis vectors alone (a training set logarithmic in the size of the input space) suffice for exact learning. We propose a multi-model ensemble mechanism based on sign-feature extraction and threshold rounding, and derive upper bounds on the ensemble complexity: $O(k \log k)$ models for pointwise generalization and $O(k^2)$ for global exact learning. Furthermore, we characterize the gradient descent dynamics exactly as an analytically tractable Gaussian process. Both theoretical analysis and empirical experiments demonstrate that the method achieves 100% permutation prediction accuracy with high probability on any $k$-dimensional nonzero binary input.
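The ensemble-and-round mechanism can be sketched numerically. This is a minimal illustration, not the paper's NTK derivation: we model each trained network's output on a test input as a Gaussian draw centered on the permuted target (standing in for the Gaussian-process mean the paper derives), average an ensemble of such draws, and round each coordinate at a 0.5 threshold. The ensemble size `M` and noise scale `noise_std` are illustrative assumptions, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

k = 8                                   # input length
perm = rng.permutation(k)               # a fixed target permutation
x = (rng.random(k) < 0.5).astype(float)
if not x.any():
    x[0] = 1.0                          # inputs must be nonzero binary vectors

target = x[perm]                        # ground-truth permuted output
M = 64                                  # ensemble size (illustrative)
noise_std = 0.4                         # stand-in for per-model output variance

# Each "model" = correct mean + Gaussian fluctuation; averaging shrinks
# the fluctuation by 1/sqrt(M), so threshold rounding recovers the target.
outputs = target + noise_std * rng.standard_normal((M, k))
prediction = (outputs.mean(axis=0) >= 0.5).astype(float)
```

Averaging `M` models reduces the per-coordinate standard deviation to `noise_std / sqrt(M)`, which is why the rounding step succeeds with high probability once the ensemble is large enough.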
📝 Abstract
The ability of an architecture to realize permutations is quite fundamental. For example, Large Language Models need to be able to correctly copy (and perhaps rearrange) parts of the input prompt into the output. Classical universal approximation theorems guarantee the existence of parameter configurations that solve this task but offer no insights into whether gradient-based algorithms can find them. In this paper, we address this gap by focusing on two-layer fully connected feed-forward neural networks and the task of learning permutations on nonzero binary inputs. We show that in the infinite-width Neural Tangent Kernel (NTK) regime, an ensemble of such networks independently trained with gradient descent on only the $k$ standard basis vectors out of $2^k - 1$ possible inputs successfully learns any fixed permutation of length $k$ with arbitrarily high probability. By analyzing the exact training dynamics, we prove that the network's output converges to a Gaussian process whose mean captures the ground truth permutation via sign-based features. We then demonstrate how averaging these runs (an "ensemble" method) and applying a simple rounding step yields an arbitrarily accurate prediction on any possible input unseen during training. Notably, the number of models needed to achieve exact learning with high probability (which we refer to as ensemble complexity) exhibits a linearithmic dependence on the input size $k$ for a single test input and a quadratic dependence when considering all test inputs simultaneously.
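The sample-efficiency claim has a simple sanity check, separate from the NTK analysis: permuting the coordinates of a binary vector is a linear map, so the $k$ standard basis vectors determine it completely. The sketch below (an illustration, not the paper's method) fits a linear map to the $k$ basis-vector pairs and verifies that it reproduces the permutation on every one of the $2^k - 1$ nonzero binary inputs.

```python
from itertools import product

import numpy as np

k = 6
rng = np.random.default_rng(1)
perm = rng.permutation(k)               # fixed permutation, illustrative

# Training set: the k standard basis vectors and their permuted images.
X_train = np.eye(k)
Y_train = np.array([x[perm] for x in X_train])

# Least-squares fit; since X_train is the identity, W equals Y_train.
W, *_ = np.linalg.lstsq(X_train, Y_train, rcond=None)

# Verify on every nonzero binary input (all 2^k - 1 of them): by
# linearity, x @ W matches x[perm] for each one.
ok = all(
    np.allclose(np.array(bits) @ W, np.array(bits)[perm])
    for bits in product([0.0, 1.0], repeat=k)
    if any(bits)
)
print(ok)
```

Of course, the paper's contribution is showing that gradient descent on a nonlinear two-layer network in the NTK regime reaches this behavior, not that a linear readout exists; this sketch only illustrates why $k$ training points can pin down the target function.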