Optimizing Automatic Differentiation with Deep Reinforcement Learning

📅 2024-06-07
🏛️ arXiv.org
📈 Citations: 1
Influential: 0

🤖 AI Summary
To address the high computational cost of computing Jacobian matrices with automatic differentiation in scientific computing, this work formulates the search for an optimal vertex elimination order as a single-agent reinforcement learning task whose objective is to minimize the number of multiplications required to evaluate the exact Jacobian. The method combines the cross-country elimination framework with deep reinforcement learning (Proximal Policy Optimization, PPO) and, unlike approaches that trade accuracy for speed, introduces no approximation of the Jacobian. A custom cross-country elimination interpreter built in JAX executes the resulting elimination orders efficiently. Evaluated on benchmarks from machine learning, computational fluid dynamics (CFD), robotics, and finance, the method reduces multiplication counts by up to 33% compared to state-of-the-art approaches, and these savings translate into measurable runtime speedups.
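A worked statement of the cost being minimized may help. The sketch below uses the classical scalar vertex-elimination rule, which is an assumption here: the paper operates on tensor-valued vertices, where the cost of a single elimination can differ. P(j) and S(j) denote the predecessors and successors of vertex j in the current graph.

```latex
% Edge j -> k carries the local partial c_{kj} = \partial v_k / \partial v_j.
% Eliminating vertex j (a missing edge i -> k starts at c_{ki} = 0):
c_{ki} \leftarrow c_{ki} + c_{kj}\, c_{ji}
  \qquad \forall\, i \in P(j),\; k \in S(j),
\qquad \mathrm{cost}(j) = |P(j)|\,|S(j)|

% Total cost of an order \pi = (\pi_1, \dots, \pi_n) over the n intermediate
% vertices, where G_t is the graph left after the first t-1 eliminations:
\mathrm{cost}(\pi) = \sum_{t=1}^{n} |P_{G_t}(\pi_t)|\,|S_{G_t}(\pi_t)|
```

Once every intermediate vertex is eliminated, the edges left between inputs and outputs are exactly the Jacobian entries, so in exact arithmetic every order yields the same Jacobian and only the multiplication count changes.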

📝 Abstract
Computing Jacobians with automatic differentiation is ubiquitous in many scientific domains such as machine learning, computational fluid dynamics, robotics and finance. Even small reductions in the number of computations or in memory usage for Jacobian computations can yield large savings in energy consumption and runtime. While many methods achieve such savings, they generally gain this efficiency by approximating the Jacobian rather than computing it exactly. In this paper, we present a novel method that minimizes the number of multiplications needed for Jacobian computation by leveraging deep reinforcement learning (RL) and a concept called cross-country elimination, while still computing the exact Jacobian. Cross-country elimination is a framework for automatic differentiation that casts Jacobian accumulation as the ordered elimination of all vertices of the computational graph, where every elimination incurs a certain computational cost. We formulate the search for the elimination order that minimizes the number of necessary multiplications as a single-player game played by an RL agent. We demonstrate that this method achieves improvements of up to 33% over state-of-the-art methods on several relevant tasks from diverse domains. Furthermore, we show that these theoretical gains translate into actual runtime improvements by providing a cross-country elimination interpreter in JAX that can efficiently execute the obtained elimination orders.
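To make the elimination game concrete, here is a minimal pure-Python sketch under stated assumptions: scalar vertices, a small hypothetical example function, and hypothetical helpers `build_graph`, `eliminate`, and `best_order`. It is not the paper's JAX interpreter, and the exhaustive search over all orders only stands in for the RL agent, which becomes necessary once the n! possible orders are too many to enumerate. Each call to `eliminate` plays one episode of the game; in an RL framing, the negative multiplication count would serve as the return.

```python
import itertools
import math

# Hypothetical example function (not from the paper):
#   v2 = x0 * x1, v3 = sin(v2), v4 = exp(v3), y0 = cos(v4), y1 = v4 ** 2
INTERMEDIATES = ("v2", "v3", "v4")


def build_graph(x0, x1):
    """Linearized graph: maps edge (src, dst) to the local partial d(dst)/d(src)."""
    v2 = x0 * x1
    v3 = math.sin(v2)
    v4 = math.exp(v3)
    return {
        ("x0", "v2"): x1,
        ("x1", "v2"): x0,
        ("v2", "v3"): math.cos(v2),
        ("v3", "v4"): v4,  # d exp(v3) / d v3 = exp(v3) = v4
        ("v4", "y0"): -math.sin(v4),
        ("v4", "y1"): 2.0 * v4,
    }


def eliminate(edges, order):
    """Eliminate the vertices in `order` one at a time (scalar vertex elimination).

    Returns the remaining edges, which are the Jacobian entries once every
    intermediate vertex is gone, and the number of multiplications performed.
    """
    edges = dict(edges)
    mults = 0
    for j in order:
        preds = [(i, c) for (i, k), c in edges.items() if k == j]
        succs = [(k, c) for (i, k), c in edges.items() if i == j]
        for i, c_ji in preds:
            for k, c_kj in succs:
                # Chain rule: fold the path i -> j -> k into the direct edge i -> k.
                edges[(i, k)] = edges.get((i, k), 0.0) + c_kj * c_ji
                mults += 1
        # Remove vertex j together with all of its incident edges.
        edges = {e: c for e, c in edges.items() if j not in e}
    return edges, mults


def best_order(edges):
    """Play every possible elimination order and keep the cheapest one.

    Exhaustive search is only feasible for tiny graphs (n! orders); this is the
    search an RL agent would replace. Ties are broken by comparing the orders.
    """
    return min(
        (eliminate(edges, order)[1], order)
        for order in itertools.permutations(INTERMEDIATES)
    )


if __name__ == "__main__":
    g = build_graph(0.3, 1.7)
    _, forward_cost = eliminate(g, ("v2", "v3", "v4"))  # topological order
    _, reverse_cost = eliminate(g, ("v4", "v3", "v2"))  # reverse topological order
    print(forward_cost, reverse_cost, best_order(g))
    # prints: 8 8 (7, ('v3', 'v2', 'v4'))
```

On this graph, the topological order (commonly identified with forward mode) and the reverse topological order (commonly identified with reverse mode) both need 8 multiplications. Eliminating the chain vertex v3 first, which has a single predecessor and a single successor, costs only 1 multiplication and brings the total to 7, while all orders produce the same four Jacobian entries.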
Problem

Automatic Differentiation
Jacobian Matrix
Computational Efficiency
Innovation

Deep Reinforcement Learning
Jacobian Computation Optimization
Cross-Country Elimination Interpreter in JAX