🤖 AI Summary
Second-order optimization methods remain underused because of implementation complexity, sensitivity to hyperparameters, and a lack of composable interfaces. This work proposes Somax, a composable, Optax-native framework that exposes curvature operators, estimators, linear solvers, preconditioners, and damping policies as first-class, swappable components. Somax separates planning from execution. It derives a static plan (including cadences) from module requirements and runs curvature-aware training as a single JIT-compiled step that reuses intermediate results across modules. Standard Optax gradient transformations such as momentum, weight decay, and schedules are then applied to the computed direction. System-oriented ablations show that composition choices materially affect scaling behavior and time-to-accuracy, and that planning reduces per-step overhead compared with unplanned composition that redundantly recomputes shared quantities.
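To make "composable components" concrete, here is a minimal pure-Python sketch of one curvature-aware step assembled from swappable pieces: a matrix-free curvature operator (a Hessian-vector product), a damping policy (H + λI), a linear solver (conjugate gradient), and a momentum transformation applied to the solved direction, loosely mirroring how Optax transformations are applied to the direction. The names `quadratic_hvp`, `damped`, `cg_solve`, and `curvature_step` are hypothetical stand-ins, not Somax's API, and the real framework targets JAX arrays rather than Python lists.

```python
from typing import Callable, List

Vec = List[float]
LinOp = Callable[[Vec], Vec]


def dot(x: Vec, y: Vec) -> float:
    return sum(a * b for a, b in zip(x, y))


def axpy(alpha: float, x: Vec, y: Vec) -> Vec:
    # Elementwise alpha * x + y.
    return [alpha * a + b for a, b in zip(x, y)]


def quadratic_hvp(A: List[Vec]) -> LinOp:
    # Curvature operator (hypothetical): matrix-free Hessian-vector product for
    # f(w) = 0.5 * w^T A w - b^T w, whose Hessian is A everywhere.
    return lambda v: [dot(row, v) for row in A]


def damped(op: LinOp, lam: float) -> LinOp:
    # Damping policy (hypothetical): turn any operator H into H + lam * I.
    return lambda v: axpy(lam, v, op(v))


def cg_solve(op: LinOp, rhs: Vec, iters: int = 50, tol: float = 1e-12) -> Vec:
    # Linear solver: conjugate gradient for op(x) = rhs, op assumed SPD.
    x = [0.0] * len(rhs)
    r = list(rhs)  # residual rhs - op(x) at x = 0
    p = list(r)
    rs = dot(r, r)
    for _ in range(iters):
        if rs <= tol:  # squared residual norm small enough
            break
        Ap = op(p)
        alpha = rs / dot(p, Ap)
        x = axpy(alpha, p, x)
        r = axpy(-alpha, Ap, r)
        rs_new = dot(r, r)
        p = axpy(rs_new / rs, p, r)  # p = r + beta * p
        rs = rs_new
    return x


def curvature_step(w, grad_fn, hvp_at, lam, lr, mu, m):
    # One step: gradient -> damped curvature operator -> CG direction,
    # then a momentum transformation on the direction (like a trace).
    g = grad_fn(w)
    d = cg_solve(damped(hvp_at(w), lam), g)
    m = axpy(mu, m, d)   # m <- mu * m + d
    w = axpy(-lr, m, w)  # w <- w - lr * m
    return w, m


A = [[3.0, 1.0], [1.0, 2.0]]
b = [1.0, 1.0]


def grad_fn(w: Vec) -> Vec:
    return axpy(-1.0, b, quadratic_hvp(A)(w))  # A w - b


w, m = curvature_step([0.0, 0.0], grad_fn, lambda w: quadratic_hvp(A),
                      lam=0.0, lr=1.0, mu=0.0, m=[0.0, 0.0])
print(w)  # should be close to the minimizer A^{-1} b = [0.2, 0.4]
```

With no damping, unit step size, and no momentum, a single undamped Newton step on a quadratic lands on the minimizer. Swapping `damped(..., lam)`, `cg_solve`, or the momentum rule changes one component without touching the others, which is the kind of explicit, swappable choice the summary describes.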
📝 Abstract
Second-order methods promise improved stability and faster convergence, yet they remain underused due to implementation overhead, tuning brittleness, and the lack of composable APIs. We introduce Somax, a composable Optax-native stack that treats curvature-aware training as a single JIT-compiled step governed by a static plan. Somax exposes first-class modules (curvature operators, estimators, linear solvers, preconditioners, and damping policies) behind a single step interface and composes with Optax by applying standard gradient transformations (e.g., momentum, weight decay, schedules) to the computed direction. This design makes typically hidden choices explicit and swappable. Somax separates planning from execution: it derives a static plan (including cadences) from module requirements, then runs the step through a specialized execution path that reuses intermediate results across modules. We report system-oriented ablations showing that (i) composition choices materially affect scaling behavior and time-to-accuracy, and (ii) planning reduces per-step overhead relative to unplanned composition with redundant recomputation.
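The separation of planning from execution can be illustrated with a toy sketch under simplifying assumptions: each module declares the intermediates it reads and a cadence, and the planner unions those requirements for every phase of the cadence period (the LCM of all cadences) so that each shared intermediate is computed once per step. `ModuleSpec`, `build_plan`, and the intermediate names are hypothetical, not Somax's API, and the costs below are bookkeeping counts of intermediate computations, not measured timings.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class ModuleSpec:
    # Hypothetical descriptor: the intermediates a module reads and its
    # cadence (the module is active on steps where step % every == 0).
    name: str
    needs: Tuple[str, ...]
    every: int = 1


def build_plan(modules: List[ModuleSpec]) -> Dict[int, Tuple[str, ...]]:
    """Derive a static plan over one cadence period.

    For each phase, list every intermediate required by some active module
    exactly once, so modules share results instead of recomputing them.
    """
    period = math.lcm(*(m.every for m in modules))
    plan = {}
    for phase in range(period):
        needed: List[str] = []
        for m in modules:
            if phase % m.every == 0:
                for q in m.needs:
                    if q not in needed:
                        needed.append(q)
        plan[phase] = tuple(needed)
    return plan


def planned_cost(plan: Dict[int, Tuple[str, ...]]) -> int:
    # Intermediates computed per period when results are shared.
    return sum(len(qs) for qs in plan.values())


def unplanned_cost(modules: List[ModuleSpec]) -> int:
    # Intermediates computed per period when each module recomputes its inputs.
    period = math.lcm(*(m.every for m in modules))
    return sum(
        len(m.needs)
        for phase in range(period)
        for m in modules
        if phase % m.every == 0
    )


modules = [
    ModuleSpec("solver", ("grad", "hvp_linearization")),
    ModuleSpec("preconditioner", ("grad", "diag_estimate"), every=10),
    ModuleSpec("damping", ("loss", "grad", "hvp_linearization"), every=5),
]
plan = build_plan(modules)
step = 25
print(plan[step % len(plan)])  # → ('grad', 'hvp_linearization', 'loss')
print(planned_cost(plan), unplanned_cost(modules))  # → 23 28
```

Because the plan is plain data fixed before execution, a JAX implementation could, for example, trace one step function per distinct phase so that each compiles to a specialized path. This is an assumption about one possible design, not a description of Somax's internals.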