🤖 AI Summary
In Neural Stochastic Differential Equation (Neural SDE) training, conventional "discretize-then-optimize" approaches yield accurate gradients but incur high memory overhead, whereas "optimize-then-discretize" methods achieve constant memory usage at the cost of gradient bias and slow inference. Algebraically reversible solvers (e.g., Reversible Heun) balance memory efficiency and gradient accuracy but suffer from numerical instability under large step sizes or complex dynamics. This work introduces Explicit and Effectively Symmetric (EES) schemes: a novel class of explicit, numerically stable, near-reversible Runge–Kutta methods for Neural SDEs. Because the backward pass reconstructs the forward trajectory algebraically rather than storing it, EES schemes attain constant memory complexity together with accurate gradients. Experiments across diverse Neural SDE tasks demonstrate substantial improvements in training stability, step-size tolerance, and model scalability.
📝 Abstract
Backpropagation through (neural) SDE solvers is traditionally approached in two ways: discretise-then-optimise, which offers accurate gradients but incurs prohibitive memory costs due to storing the full computational graph (even when mitigated by checkpointing); and optimise-then-discretise, which achieves constant memory cost by solving an auxiliary backward SDE, but suffers from slower evaluation and gradient approximation errors. Algebraically reversible solvers promise both memory efficiency and gradient accuracy, yet existing methods such as the Reversible Heun scheme are often unstable under complex models and large step sizes. We address these limitations by introducing a novel class of stable, near-reversible Runge–Kutta schemes for neural SDEs. These Explicit and Effectively Symmetric (EES) schemes retain the benefits of reversible solvers while overcoming their instability, enabling memory-efficient training without severe restrictions on step size or model complexity. Through numerical experiments, we demonstrate the superior stability and reliability of our schemes, establishing them as a practical foundation for scalable and accurate training of neural SDEs.
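The paper does not spell out the EES update rule here, but the key background idea, algebraic reversibility, can be illustrated with the Reversible Heun scheme both sections mention. The sketch below is a drift-only (ODE-style) simplification with a made-up toy vector field, not the authors' method: each forward step maps a state pair (y, ŷ) forward, and the backward step inverts it in closed form, so the backward pass can rebuild the trajectory instead of storing it, which is what gives constant memory cost.

```python
import numpy as np

def reversible_heun_step(y, yhat, f, h):
    """One forward step of the (drift-only) Reversible Heun scheme.

    The solver state is the pair (y, yhat); both are advanced together.
    """
    yhat_next = 2.0 * y - yhat + h * f(yhat)
    y_next = y + 0.5 * h * (f(yhat) + f(yhat_next))
    return y_next, yhat_next

def reversible_heun_step_back(y_next, yhat_next, f, h):
    """Algebraic inverse of the forward step: recovers (y, yhat) in
    closed form, exactly up to floating-point roundoff."""
    yhat = 2.0 * y_next - yhat_next - h * f(yhat_next)
    y = y_next - 0.5 * h * (f(yhat_next) + f(yhat))
    return y, yhat

# Toy nonlinear drift (hypothetical, for illustration only).
f = lambda y: np.sin(y) - 0.5 * y

y = yhat = np.array([1.0, -0.3])
h = 0.1
trajectory = [(y, yhat)]
for _ in range(50):
    y, yhat = reversible_heun_step(y, yhat, f, h)
    trajectory.append((y, yhat))

# Walk backwards: every earlier state is reconstructed algebraically,
# so no forward trajectory needs to be stored for backpropagation.
for y_prev, yhat_prev in reversed(trajectory[:-1]):
    y, yhat = reversible_heun_step_back(y, yhat, f, h)
    assert np.allclose(y, y_prev, atol=1e-9)
    assert np.allclose(yhat, yhat_prev, atol=1e-9)
```

In exact arithmetic the inverse is exact; in floating point it agrees to roundoff, which is the "algebraic reversibility" the abstract contrasts with the approximate backward SDE of optimise-then-discretise. The instability the paper targets arises because this recursion can amplify errors for stiff dynamics or large h, which is the gap the EES schemes are designed to close.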