A projection-based framework for gradient-free and parallel learning

📅 2025-06-06
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
This paper addresses the serial bottleneck in neural network training caused by gradient computation and the difficulty of supporting non-differentiable operations. To this end, it proposes a gradient-free parallel training paradigm. Methodologically, training is reformulated as a large-scale feasibility problem, where local constraint projections replace gradient-based optimization. The authors introduce PJAX—a novel framework that automatically synthesizes constraint projection operators (analogous to automatic differentiation, but for feasibility solving), implemented atop JAX with a NumPy-like API. PJAX supports compositional projections and iterative projection algorithms, and natively leverages GPU/TPU acceleration. Experiments demonstrate successful training of MLPs, CNNs, and RNNs on standard benchmarks, achieving accuracy comparable to gradient-based methods while significantly improving parallel efficiency and natively accommodating non-differentiable components.
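The core idea above, replacing gradient steps with projections onto local constraints, can be illustrated with the classic method of alternating projections. The sketch below is our own minimal illustration in plain NumPy, not the paper's PJAX API: it finds a point satisfying two constraints (membership in a hyperplane and in the nonnegative orthant) purely by repeated local projections, with no gradients involved.

```python
import numpy as np

def project_hyperplane(x, a, b):
    # Euclidean projection onto the hyperplane {x : a.x = b}
    return x - ((a @ x - b) / (a @ a)) * a

def project_nonneg(x):
    # Euclidean projection onto the nonnegative orthant {x : x >= 0}
    return np.maximum(x, 0.0)

a = np.array([1.0, 2.0, -1.0])
b = 3.0
x = np.array([5.0, -4.0, 2.0])  # arbitrary infeasible start

# Alternating projections: each step is local and gradient-free.
for _ in range(200):
    x = project_nonneg(project_hyperplane(x, a, b))

# x now (approximately) satisfies both constraints simultaneously.
```

In the paper's setting, the constraint sets come from the network's elementary operations rather than from hand-written hyperplanes, and the iterations are parallelized across the network.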

📝 Abstract
We present a feasibility-seeking approach to neural network training. This mathematical optimization framework is distinct from conventional gradient-based loss minimization and uses projection operators and iterative projection algorithms. We reformulate training as a large-scale feasibility problem: finding network parameters and states that satisfy local constraints derived from its elementary operations. Training then involves projecting onto these constraints, a local operation that can be parallelized across the network. We introduce PJAX, a JAX-based software framework that enables this paradigm. PJAX composes projection operators for elementary operations, automatically deriving the solution operators for the feasibility problems (akin to autodiff for derivatives). It inherently supports GPU/TPU acceleration, provides a familiar NumPy-like API, and is extensible. We train diverse architectures (MLPs, CNNs, RNNs) on standard benchmarks using PJAX, demonstrating its functionality and generality. Our results show that this approach is a compelling alternative to gradient-based training, with clear advantages in parallelism and the ability to handle non-differentiable operations.
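The abstract's claim of "local constraints derived from elementary operations" and native support for non-differentiable components can be made concrete with one such operation. The snippet below is a hypothetical illustration (not code from the paper): it gives a closed-form Euclidean projection onto the graph of ReLU, the set {(u, v) : v = max(u, 0)}, which is exactly the kind of elementary, non-differentiable constraint a projection-composing framework could lift and combine.

```python
import numpy as np

def project_relu_graph(p, q):
    """Euclidean projection of the point (p, q) onto {(u, v) : v = relu(u)}.

    The graph is the union of two pieces; we project onto each and keep
    the nearer candidate.
    """
    # Piece 1: the ray v = u with u >= 0
    t = max((p + q) / 2.0, 0.0)
    cand1 = np.array([t, t])
    # Piece 2: the flat segment v = 0 with u <= 0
    cand2 = np.array([min(p, 0.0), 0.0])
    point = np.array([p, q])
    d1 = np.linalg.norm(point - cand1)
    d2 = np.linalg.norm(point - cand2)
    return cand1 if d1 <= d2 else cand2

u, v = project_relu_graph(-1.0, 2.0)  # projected pair satisfies v = relu(u)
```

Because the projection needs no derivative of ReLU, the same recipe extends to genuinely non-differentiable ops (sign, quantization) where backpropagation breaks down.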
Problem

Research questions and friction points this paper is trying to address.

Proposes gradient-free neural network training using projection operators
Reformulates training as parallelizable feasibility problem with constraints
Introduces PJAX framework for GPU/TPU-accelerated projection-based learning
Innovation

Methods, ideas, or system contributions that make the work stand out.

Projection-based framework for gradient-free learning
Parallel training using iterative projection algorithms
PJAX software for automatic feasibility problem solving