Learning Gaussian Multi-Index Models with Gradient Flow: Time Complexity and Directional Convergence

📅 2024-11-13
🏛️ arXiv.org
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work analyzes the gradient flow dynamics of neural networks trained with a correlation loss to learn multi-index functions $f^*(x)=\sum_{j=1}^k \sigma^*(v_j^\top x)$ on high-dimensional standard Gaussian inputs. Extending known single-index results, the paper shows that overcoming the search phase requires polynomial time for index vectors in arbitrary directions, and then asks whether, after the search phase, neurons converge to the index vectors or get stuck at suboptimal solutions. It systematically characterizes the fixed-point structure and global convergence under orthogonal and equiangular index configurations, revealing a critical correlation threshold $\eta_c = c/(c+k)$ at which the optimization landscape undergoes a phase transition. Under orthogonality, only $n \asymp k \log k$ neurons suffice to recover all index vectors $\{v_j\}$ with high probability over random initialization. Numerical experiments confirm that mild overparameterization succeeds near orthogonality but fails under high index correlations.
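A standard Gaussian identity (assumed background, not stated in the summary) explains why the condition that $\sigma^*$ lacks the first and second Hermite polynomials (per the abstract below) controls the search phase. For unit vectors $u, w$ and $x \sim \mathcal{N}(0, I_d)$,

$$\mathbb{E}\left[\mathrm{He}_p(u^\top x)\,\mathrm{He}_q(w^\top x)\right] = \delta_{pq}\, p!\, (u^\top w)^p,$$

so the leading correlation between a neuron $w$ and an index vector $v_j$ scales as $(v_j^\top w)^p$ with $p \geq 3$. At random initialization $v_j^\top w \approx d^{-1/2}$, so the initial gradient signal is polynomially small in $d$, which is the source of the polynomial search-phase time.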

📝 Abstract
This work focuses on the gradient flow dynamics of a neural network model that uses correlation loss to approximate a multi-index function on high-dimensional standard Gaussian data. Specifically, the multi-index function we consider is a sum of neurons $f^*(x) = \sum_{j=1}^k \sigma^*(v_j^T x)$ where $v_1, \dots, v_k$ are unit vectors, and $\sigma^*$ lacks the first and second Hermite polynomials in its Hermite expansion. It is known that, for the single-index case ($k=1$), overcoming the search phase requires polynomial time complexity. We first generalize this result to multi-index functions characterized by vectors in arbitrary directions. After the search phase, it is not clear whether the network neurons converge to the index vectors or get stuck at a sub-optimal solution. When the index vectors are orthogonal, we give a complete characterization of the fixed points and prove that neurons converge to the nearest index vectors. Therefore, using $n \asymp k \log k$ neurons ensures finding the full set of index vectors with gradient flow with high probability over random initialization. When $v_i^T v_j = \eta \geq 0$ for all $i \neq j$, we prove the existence of a sharp threshold $\eta_c = c/(c+k)$ at which the fixed point that computes the average of the index vectors transitions from a saddle point to a minimum. Numerical simulations show that a correlation loss with mild overparameterization suffices to learn all of the index vectors when they are nearly orthogonal; however, the correlation loss fails when the dot product between the index vectors exceeds a certain threshold.
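The orthogonal case can be sketched numerically. The following is a minimal illustration, not the paper's implementation: it assumes $\sigma^* = \mathrm{He}_3$ (whose Hermite expansion indeed lacks $\mathrm{He}_1$ and $\mathrm{He}_2$), uses the identity $\mathbb{E}[\mathrm{He}_3(u^\top x)\,\mathrm{He}_3(w^\top x)] = 6\,(u^\top w)^3$ to write the population correlation objective of a single neuron in closed form, and integrates spherical gradient flow with Euler steps; the dimensions, step size, and step count are illustrative choices.

```python
import numpy as np

# Sketch (assumed setup): sigma* = He_3(z) = z^3 - 3z. For unit u, w and
# x ~ N(0, I_d), E[He_3(u.x) He_3(w.x)] = 3! (u.w)^3, so the population
# correlation objective of one neuron w is J(w) = 6 * sum_j (v_j . w)^3.
rng = np.random.default_rng(0)
d, k = 32, 4
V = np.eye(d)[:k]                  # orthogonal index vectors v_1, ..., v_k

def spherical_grad(w):
    """Riemannian gradient of J on the unit sphere."""
    m = V @ w                      # overlaps m_j = v_j . w
    g = 18.0 * (m ** 2) @ V        # Euclidean gradient of 6 * sum_j m_j^3
    return g - (g @ w) * w         # project onto the tangent space at w

# Euler discretization of gradient flow dw/dt = grad J(w), w kept on the sphere.
w = rng.standard_normal(d)
w /= np.linalg.norm(w)
for _ in range(5000):
    w += 0.01 * spherical_grad(w)
    w /= np.linalg.norm(w)

overlaps = V @ w
print(np.round(overlaps, 3))       # one overlap near 1: w found some v_j
```

Consistent with the fixed-point characterization quoted above, the neuron ends up aligned with a single index vector rather than a mixture; the balanced fixed point (equal overlap with all $v_j$) is a saddle and is not reached from generic initialization.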
Problem

Research questions and friction points this paper is trying to address.

Analyzes gradient flow dynamics in neural networks for multi-index functions.
Determines conditions for neurons to converge to index vectors.
Identifies thresholds for learning index vectors with correlation loss.
Innovation

Methods, ideas, or system contributions that make the work stand out.

Gradient flow dynamics for multi-index functions
Neurons converge to nearest orthogonal index vectors
Correlation loss with overparameterization learns index vectors
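The $n \asymp k \log k$ neuron count has a coupon-collector reading: each neuron converges to its nearest index vector, and if that nearest vector is roughly uniform over the $k$ directions at random initialization, then covering all $k$ directions is the classical coupon-collector problem. A hypothetical simulation of this reduction (the constants $k = 50$ and $n = 3k\log k$ are illustrative, not from the paper):

```python
import numpy as np

# Coupon-collector sketch: each of n neurons lands on one of k index
# vectors uniformly at random; check how often all k are recovered.
rng = np.random.default_rng(1)
k = 50
n = int(3 * k * np.log(k))         # mild overparameterization, n ~ k log k

trials = 200
hits = 0
for _ in range(trials):
    assigned = rng.integers(0, k, size=n)   # index vector each neuron finds
    if np.unique(assigned).size == k:       # all k directions covered?
        hits += 1
print(hits / trials)               # close to 1 for n ~ k log k
```

With $n = c\,k \log k$, the chance a fixed direction is missed is $(1 - 1/k)^n \approx k^{-c}$, so a union bound over the $k$ directions gives failure probability about $k^{1-c}$, matching the high-probability recovery claim.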