🤖 AI Summary
Deep ReLU networks suffer from neuron death (“dying ReLU”) and gradient instability due to suboptimal weight initialization; conventional schemes (e.g., He, Xavier, orthogonal initialization) fail to jointly control pre-activation mean, sparsity, and variance stability—especially in extremely deep architectures.
Method: We formulate weight initialization as an optimization problem on the Stiefel manifold, explicitly incorporating the ReLU nonlinearity’s statistical prior, and derive a closed-form family of orthogonal initializations with an efficient sampling strategy.
Contribution/Results: Our method theoretically guarantees exact calibration of pre-activation statistics, thereby eliminating neuron death at initialization and mitigating gradient vanishing and variance decay. Empirically, it consistently outperforms state-of-the-art initialization methods across MNIST, Fashion-MNIST, tabular datasets, and few-shot learning tasks. Notably, it maintains training stability and convergence even in networks exceeding 100 layers—where existing approaches collapse.
📝 Abstract
Stable and efficient training of deep ReLU networks is highly sensitive to weight initialization. Improper initialization can cause permanent neuron inactivation (the "dying ReLU" problem) and exacerbate gradient instability as network depth increases. Methods such as He, Xavier, and orthogonal initialization preserve variance or promote approximate isometry; however, they do not necessarily regulate the pre-activation mean or control activation sparsity, and their effectiveness often diminishes in very deep architectures. This work introduces an orthogonal initialization specifically optimized for ReLU by solving an optimization problem on the Stiefel manifold, thereby preserving scale and calibrating the pre-activation statistics from the outset. A family of closed-form solutions and an efficient sampling scheme are derived. Theoretical analysis at initialization establishes prevention of the dying ReLU problem, slower decay of activation variance, and mitigation of gradient vanishing, which together stabilize signal and gradient flow in deep architectures. Empirically, across MNIST, Fashion-MNIST, multiple tabular datasets, few-shot settings, and ReLU-family activations, our method consistently outperforms prior initializations and enables stable training in deep networks.
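For context, the baseline the abstract builds on can be sketched as plain orthogonal initialization (QR of a Gaussian matrix) combined with the ReLU gain √2, which preserves the second moment of activations across layers but, unlike the proposed Stiefel-manifold method, does not calibrate the pre-activation mean or sparsity. The function name and the sanity check below are illustrative, not the paper's algorithm:

```python
import numpy as np

def orthogonal_relu_init(fan_out, fan_in, gain=np.sqrt(2.0), rng=None):
    """Baseline: Saxe-style orthogonal init scaled by the ReLU gain sqrt(2).

    Illustrative only; the paper's method additionally optimizes over the
    Stiefel manifold to calibrate pre-activation statistics, which this
    sketch does not attempt.
    """
    rng = np.random.default_rng(rng)
    a = rng.standard_normal((max(fan_out, fan_in), min(fan_out, fan_in)))
    q, r = np.linalg.qr(a)
    q *= np.sign(np.diag(r))  # sign-correct so the factorization is unique
    if fan_out < fan_in:
        q = q.T
    return gain * q[:fan_out, :fan_in]

# Sanity check: rows are orthogonal with squared norm gain^2 = 2, so a
# pre-activation has twice the input's second moment; ReLU halves it,
# keeping the post-activation second moment roughly constant with depth.
rng = np.random.default_rng(0)
x = rng.standard_normal((512, 4096))
for _ in range(20):
    W = orthogonal_relu_init(512, 512, rng=rng)
    x = np.maximum(W @ x, 0.0)  # ReLU
print(float((x ** 2).mean()))  # stays near 1 rather than decaying to 0
```

Even this baseline keeps signal magnitude stable over 20 layers, which is why the abstract credits orthogonal schemes with approximate isometry; the claimed improvement is in the statistics the gain alone cannot fix (mean shift and dead units).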