🤖 AI Summary
Existing dual-numbers reverse-mode automatic differentiation (AD) performs poorly in practice on multidimensional array programs. This paper introduces a dual-numbers reverse AD method with first-class support for arrays in a functional array language. The approach has three loosely coupled components: (1) a semantics-preserving vectorizing code transformation, the bulk-operation transform (BOT), which eliminates the higher-order structure of the input program that matters for AD; (2) a lifting of the basic dual-numbers reverse AD algorithm to array types in a mostly first-order array language; and (3) symbolic interpretation, which turns the method into an end-to-end compilation pipeline. Support for higher-order code is restricted to a carefully chosen set of array combinators: `build`, `gather`, and `scatter`. The resulting pipeline differentiates array programs with little to no performance overhead, at the cost of losing the general higher-order support of dual-numbers AD.
📝 Abstract
The standard dual-numbers construction works well for forward-mode automatic differentiation (AD) and is attractive due to its simplicity; recently, it has also been adapted to reverse-mode AD, but practical performance, especially on array programs, leaves a lot to be desired. In this paper we introduce first-class support for multidimensional arrays in dual-numbers reverse-mode AD with little to no performance overhead. The algorithm consists of three loosely coupled components: a semantics-preserving vectorisation code transformation (the bulk-operation transform or BOT), a fairly straightforward lifting of the basic dual-numbers reverse AD algorithm to a mostly first-order array language, and symbolic interpretation to achieve an end-to-end compilation pipeline. Unfortunately, we lose some of the nice generalisable aspects of dual-numbers AD in the process, most importantly support for higher-order code.
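As background for the "standard dual-numbers construction" mentioned above, the following is a minimal forward-mode sketch in Python. It is not taken from the paper: the names `Dual`, `lift`, `derivative` and the example function `f` are hypothetical, and only addition and multiplication are covered. Each value carries its tangent alongside it, and every primitive updates both.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A dual number: a primal value paired with its tangent (hypothetical name)."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = lift(other)
        # Sum rule: tangents add.
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Product rule: (uv)' = u v' + u' v.
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def lift(x):
    """Treat a plain number as a constant, i.e. a dual number with zero tangent."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def derivative(f, x):
    """Forward-mode derivative of a scalar function f at x (seed tangent 1)."""
    return f(Dual(float(x), 1.0)).tangent


def f(x):
    # Illustrative function: f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2.
    return x * x * x + 2 * x
```

Calling `derivative(f, 3.0)` evaluates `f` once on the dual number `Dual(3.0, 1.0)` and reads off the tangent. Reverse-mode variants of this construction replace the tangent component with something that records how to propagate cotangents backwards; that part is beyond this sketch.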
We do support some higher-order array combinators, but only a carefully chosen set: 'build' (elementwise array construction), 'gather' and 'scatter'. In return, the BOT can eliminate the essential (for AD) higher-orderness of the input program, meaning that AD is essentially presented with a first-order program. This allows the naive trick of lifting dual numbers to "dual arrays" to work without much modification.
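To make the three combinators and the idea of vectorisation concrete, here is a rough sketch on one-dimensional Python lists. It is an illustration under assumptions, not the paper's BOT: the function names, the restriction to 1-D lists, and the choice that `scatter` accumulates with addition are all assumptions made for this example. The sketch writes one small program twice, once with a per-element lambda and once using only bulk operations, and then gives the reverse derivative of the bulk version by hand.

```python
def build(n, f):
    """Elementwise array construction: [f(0), ..., f(n-1)]."""
    return [f(i) for i in range(n)]


def gather(src, idx):
    """out[i] = src[idx[i]]."""
    return [src[j] for j in idx]


def scatter(n, idx, vals):
    """Length-n array with out[idx[i]] += vals[i] (accumulation is assumed here)."""
    out = [0.0] * n
    for j, v in zip(idx, vals):
        out[j] += v
    return out


def bulk_mul_scalar(c, xs):
    """Bulk primitive: multiply every element by the scalar c."""
    return [c * x for x in xs]


def scale_reversed_build(a):
    """Higher-order form: floating-point work happens inside the lambda."""
    n = len(a)
    return build(n, lambda i: 2.0 * a[n - 1 - i])


def scale_reversed_bulk(a):
    """Hand-vectorised form: the lambda only computes indices."""
    n = len(a)
    idx = build(n, lambda i: n - 1 - i)   # integer index array, nothing to differentiate
    return bulk_mul_scalar(2.0, gather(a, idx))


def scale_reversed_vjp(a, d_out):
    """Reverse derivative of scale_reversed_bulk: cotangent of a given cotangent d_out.

    The multiply is linear, so its cotangent is scaled by the same factor;
    the transpose of gather is an accumulating scatter.
    """
    n = len(a)
    idx = build(n, lambda i: n - 1 - i)
    return scatter(n, idx, bulk_mul_scalar(2.0, d_out))
```

In the bulk form every differentiable step is a whole-array primitive, so a derivative rule can be stated once per primitive rather than once per element. That is the sense in which a first-order program is easier to hand to an array-level AD algorithm.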