A lift for input-convex neural network training

📅 2026-05-22
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the optimization difficulties of input-convex neural networks (ICNNs), whose non-negative weight constraints often cause vanishing gradients and stalled training. To overcome these limitations, the authors propose a “lift”: an unconstrained hypernetwork that generates the ICNN's inter-layer weights from a permutation-invariant summary of each input batch. The approach combines a learnable bias, batch conditioning, and a cross-covariance that couples the two through batch stochasticity; together these soften the loss landscape and help training escape optimization plateaus. On log-concave energy-based modeling and convex-potential normalizing flows, the method reaches lower test loss than both projected gradient descent (PGD) and softplus reparameterization, and it turns flat training plateaus into sustained descent.
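The lift can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' architecture: a one-hidden-layer ICNN on a scalar input, a mean-pooled feature vector as the permutation-invariant batch summary, a linear hypernetwork with a learnable bias, and a softplus output map that keeps the emitted weights non-negative. All names (`batch_summary`, `hypernet_weights`, `icnn`, `potential`) and the feature choice are hypothetical. Only the hidden-to-output weights `Wz` need non-negativity; the first layer and the direct input term stay unconstrained, as in the standard ICNN construction.

```python
import math
import random


def softplus(t):
    # Numerically stable log(1 + exp(t)); non-negative for every t.
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def relu(t):
    # Convex and non-decreasing, as an ICNN activation must be.
    return max(t, 0.0)


def batch_summary(batch):
    # Permutation-invariant summary: the mean of fixed per-sample features
    # (mean-pooling). The feature choice (x, x^2, tanh x) is hypothetical.
    n = len(batch)
    feats = [(x, x * x, math.tanh(x)) for x in batch]
    return [sum(f[k] for f in feats) / n for k in range(3)]


def hypernet_weights(summary, A, bias):
    # Unconstrained linear hypernetwork: raw_j = A[j] . summary + bias[j].
    # A and bias are the trainable parameters; softplus (an assumed output
    # map) turns the raw outputs into non-negative inter-layer weights.
    return [softplus(sum(a * s for a, s in zip(row, summary)) + b)
            for row, b in zip(A, bias)]


def icnn(x, W0, c0, Wz, wx, c):
    # One-hidden-layer ICNN on a scalar input x:
    #   z = relu(W0 * x + c0)     first layer, unconstrained
    #   f = Wz . z + wx * x + c   Wz >= 0 keeps f convex in x
    z = [relu(w * x + b) for w, b in zip(W0, c0)]
    return sum(wz * zj for wz, zj in zip(Wz, z)) + wx * x + c


# Usage: emit the ICNN weights from one batch, then evaluate the potential.
random.seed(0)
H = 4  # hidden width
W0 = [random.gauss(0.0, 1.0) for _ in range(H)]
c0 = [random.gauss(0.0, 1.0) for _ in range(H)]
A = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(H)]
bias = [0.0] * H  # learnable bias (the "slack" ingredient)

batch = [random.gauss(0.0, 1.0) for _ in range(32)]
Wz = hypernet_weights(batch_summary(batch), A, bias)


def potential(x):
    # Convex potential defined by the batch-conditioned weights.
    return icnn(x, W0, c0, Wz, wx=0.5, c=0.0)
```

Because the emitted weights depend on the batch, each minibatch yields a slightly different convex potential; this is one reading of the batch stochasticity described above. Convexity holds for every batch, since `relu` is convex and non-decreasing and every entry of `Wz` is non-negative.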
📝 Abstract
Input-convex neural networks (ICNNs) are widely used for log-concave density estimation, convex-potential normalizing flows, optimal transport, and transport-map inversion for high-dimensional Bayesian posteriors. These tasks share a structural constraint: the inter-layer weights of the ICNN must remain non-negative. The standard recipe, projected gradient descent (PGD) onto the non-negative cone, applies a hard, non-smooth projection (the stiff-penalty limit of an ADMM-style constraint splitting), and its classical convergence guarantees do not transfer to the non-smooth ICNN training landscape. The differentiable alternative, softplus reparametrization, attenuates the gradient exponentially in the magnitude of the negative raw parameter, stalling training with dead inter-layer weights and a plateaued loss. Inspired by parameter-extension lifts of PDE-constrained inverse problems, we propose the lift: instead of constraining the inter-layer weights directly, we train an unconstrained hypernetwork that emits them from a permutation-invariant summary of the input batch. This injects stochasticity into the training dynamics, which softens the loss landscape and lets the iterates escape the gradient-attenuated region where direct softplus stalls. We trace this softening to three structural ingredients (a learnable bias acting as slack, a hypernetwork body that conditions on the target batch, and a cross-covariance coupling the two through batch stochasticity) and prove that each one is necessary: deleting any single ingredient collapses the cross-covariance that carries the softening. On log-concave energy-based modeling, from one-dimensional toy targets to image-flavored latents, and on convex-potential normalizing flows for a 21-dimensional tabular benchmark, we show that the lift reaches a lower test loss than both PGD and direct softplus and turns a plateau-bounded training trajectory into a valley-descending one.
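The two baselines can be contrasted on a single weight. This is a minimal scalar sketch, assuming a toy loss L(w) = 0.5 (w - 1)^2; the loss, learning rate, starting points, and helper names (`pgd_step`, `softplus_step`, `grad_loss`) are hypothetical and not taken from the paper. PGD takes a gradient step and clips to the non-negative cone. Softplus reparametrization sets w = softplus(θ), so the chain rule multiplies the gradient by softplus′(θ) = sigmoid(θ), which behaves like e^θ for very negative θ. The toy is smooth, so it isolates only the softplus attenuation and says nothing about how PGD behaves on the non-smooth ICNN landscape.

```python
import math


def sigmoid(t):
    # Numerically stable logistic function; it equals softplus'(t).
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def softplus(t):
    # Numerically stable log(1 + exp(t)).
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def grad_loss(w):
    # Toy loss L(w) = 0.5 * (w - 1)^2, so dL/dw = w - 1 (hypothetical target 1).
    return w - 1.0


def pgd_step(w, lr):
    # Gradient step on w, then Euclidean projection onto the non-negative cone.
    return max(w - lr * grad_loss(w), 0.0)


def softplus_step(theta, lr):
    # w = softplus(theta); chain rule gives dL/dtheta = dL/dw * sigmoid(theta).
    return theta - lr * grad_loss(softplus(theta)) * sigmoid(theta)


lr, steps = 0.1, 100
w_pgd = 0.0    # PGD iterate starting on the boundary of the cone
theta = -10.0  # softplus raw parameter; softplus(-10) is about 4.5e-5
for _ in range(steps):
    w_pgd = pgd_step(w_pgd, lr)
    theta = softplus_step(theta, lr)
w_softplus = softplus(theta)

print(f"PGD:      w = {w_pgd:.5f}")
print(f"softplus: w = {w_softplus:.2e}, gradient factor = {sigmoid(theta):.2e}")
```

From w = 0, PGD follows w_k = 1 - 0.9^k and ends near 0.99997. From θ = -10, each softplus step moves θ by only about 0.1 · e^(-10) ≈ 4.5 × 10⁻⁶, so after 100 steps the weight is still about 4.5 × 10⁻⁵: a dead inter-layer weight in its simplest form.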
Problem

Research questions and friction points this paper is trying to address.

input-convex neural networks
non-negative weights
gradient attenuation
training stagnation
structural constraints
Innovation

Methods, ideas, or system contributions that make the work stand out.

input-convex neural networks
hypernetwork
lifted optimization
non-negative weight constraints
batch stochasticity