🤖 AI Summary
This work investigates the sample complexity of shallow neural networks trained with gradient descent when input features exhibit a clustered correlation structure. In the setting studied, features are grouped into clusters correlated with a small number of latent Boolean variables. The authors propose an analytically tractable model of this setting and analyze a layerwise variant of gradient descent on it. Under an identifiability assumption, the analysis shows that the required sample size scales with the number of latent variables. When the signal-to-noise ratio is sufficiently high, it is also independent of the input dimension up to logarithmic terms. Here the low-dimensional structure lies in the inputs themselves. Conventional models instead place it in a target function that projects unstructured inputs onto a low-dimensional subspace. Experiments on both synthetic and real-world data corroborate these findings and shed light on how neural networks can learn efficiently from structured inputs.
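To make the clustered data model concrete, here is a small generator. It is a sketch under our own assumptions, not the authors' construction. We assume latent bits are uniform on {-1, +1} and that each latent drives one contiguous cluster of `cluster_size` features of the form `snr * z + N(0, 1)` noise. The label is some Boolean function of the latents; parity is used purely for illustration. The names `sample_clustered_data`, `recover_latents`, `parity`, and `snr` are hypothetical.

```python
import math
import random
from typing import Callable, List, Tuple


def sample_clustered_data(
    n: int,
    k: int,
    cluster_size: int,
    snr: float,
    target: Callable[[List[int]], float],
    seed: int = 0,
) -> Tuple[List[List[float]], List[List[int]], List[float]]:
    """Hypothetical clustered-input model with k latent Boolean variables.

    Input dimension is d = k * cluster_size; feature j belongs to cluster
    j // cluster_size and equals snr * z[cluster] + N(0, 1) noise.
    The label depends on the latent vector z only.
    """
    rng = random.Random(seed)
    xs, zs, ys = [], [], []
    for _ in range(n):
        z = [rng.choice((-1, 1)) for _ in range(k)]
        x = [snr * z[j // cluster_size] + rng.gauss(0.0, 1.0)
             for j in range(k * cluster_size)]
        xs.append(x)
        zs.append(z)
        ys.append(target(z))
    return xs, zs, ys


def recover_latents(x: List[float], k: int, cluster_size: int) -> List[int]:
    """Estimate each latent bit from the sign of its cluster's feature sum."""
    return [
        1 if sum(x[c * cluster_size:(c + 1) * cluster_size]) >= 0.0 else -1
        for c in range(k)
    ]


def parity(z: List[int]) -> float:
    """Illustrative target: product (parity) of the latent bits."""
    return float(math.prod(z))
```

For instance, `sample_clustered_data(n=50, k=3, cluster_size=20, snr=3.0, target=parity)` yields 60-dimensional inputs whose labels depend on only 3 bits. Averaging the `cluster_size` noisy copies of a bit shrinks the noise standard deviation by a factor of `sqrt(cluster_size)`. This gives one informal intuition for why high signal-to-noise ratios make the latents recoverable. That intuition is ours and is not the paper's proof.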
📝 Abstract
The success of deep learning in high-dimensional settings is often attributed to the presence of low-dimensional structure in real-world data. Standard theoretical models typically place this structure in the target function, which projects unstructured inputs onto a low-dimensional subspace. By contrast, data such as images, text, or genomic sequences exhibit strong spatial correlations within the input space itself. In this paper, we propose a tractable model to study how these correlations affect the sample complexity of learning with gradient descent on shallow neural networks. Specifically, we consider targets that depend on a small number of latent Boolean variables, and input features grouped into clusters and correlated with the latent variables. Under an identifiability assumption, we show that for a layerwise gradient-descent variant, the sample complexity scales with the number of latent variables. When the signal-to-noise ratio is sufficiently high, it is also independent of the input dimension, up to logarithmic terms. We empirically test our theoretical findings on both synthetic and real data.
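The abstract does not specify the layerwise gradient-descent variant, so the sketch below follows one common reading. This reading is an assumption, not the paper's algorithm. First, train the first-layer weights of a two-layer ReLU network with the output weights frozen at small random values. Then freeze the first layer, reset the output weights to zero, and fit them by gradient descent on the resulting convex least-squares problem. `train_layerwise`, its widths, step counts, and learning rates are all hypothetical. The phase-2 step size `1 / tr(Phi^T Phi / n)` is a conservative choice. The trace bounds the largest Hessian eigenvalue from above, so the recorded loss cannot increase.

```python
import math
import random
from typing import List, Tuple


def relu(u: float) -> float:
    return u if u > 0.0 else 0.0


def train_layerwise(
    xs: List[List[float]],
    ys: List[float],
    width: int = 16,
    steps1: int = 50,
    steps2: int = 200,
    lr1: float = 0.01,
    seed: int = 0,
) -> Tuple[List[List[float]], List[float], List[float], List[float]]:
    """Hypothetical two-phase GD for f(x) = sum_i a_i * relu(w_i . x + b_i).

    Squared loss (1 / 2n) * sum (f(x) - y)^2; biases b stay fixed.
    Returns (w, b, a, phase-2 loss history).
    """
    rng = random.Random(seed)
    n, d = len(xs), len(xs[0])
    w = [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)] for _ in range(width)]
    b = [rng.gauss(0.0, 1.0) for _ in range(width)]
    a = [rng.choice((-1.0, 1.0)) / width for _ in range(width)]

    def pre_acts(x: List[float]) -> List[float]:
        # Hidden pre-activations w_i . x + b_i (w is updated in place below).
        return [sum(wij * xj for wij, xj in zip(wi, x)) + bi for wi, bi in zip(w, b)]

    # Phase 1: gradient descent on the first layer w, output weights a frozen.
    for _ in range(steps1):
        grad = [[0.0] * d for _ in range(width)]
        for x, y in zip(xs, ys):
            pre = pre_acts(x)
            r = sum(ai * relu(p) for ai, p in zip(a, pre)) - y
            for i in range(width):
                if pre[i] > 0.0:  # ReLU derivative is 1 on the active side
                    g = r * a[i] / n
                    row = grad[i]
                    for j in range(d):
                        row[j] += g * x[j]
        for i in range(width):
            for j in range(d):
                w[i][j] -= lr1 * grad[i][j]

    # Phase 2: freeze w, reset a to zero, fit a by GD on a convex quadratic.
    feats = [[relu(p) for p in pre_acts(x)] for x in xs]
    # tr(Phi^T Phi / n) upper-bounds the largest Hessian eigenvalue.
    trace = sum(f * f for row in feats for f in row) / n
    lr2 = 1.0 / trace if trace > 0.0 else 0.0
    a = [0.0] * width
    losses: List[float] = []
    for _ in range(steps2):
        res = [sum(ai * fi for ai, fi in zip(a, row)) - y for row, y in zip(feats, ys)]
        losses.append(0.5 * sum(r * r for r in res) / n)
        grad_a = [sum(r * row[i] for r, row in zip(res, feats)) / n for i in range(width)]
        a = [ai - lr2 * gi for ai, gi in zip(a, grad_a)]
    return w, b, a, losses
```

The function takes a list of input vectors `xs` and labels `ys`. It returns the trained weights and the phase-2 loss curve. In this design, any adaptation to the cluster structure of the inputs has to happen in phase 1, while phase 2 is a plain convex fit on the learned features.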