🤖 AI Summary
Problem: SciPy’s `spatial.transform` module historically supported only NumPy, which hindered its integration into GPU-accelerated and differentiable machine learning workflows. Method: We comprehensively refactor the module so that it works with any array library implementing the Python Array API Standard, including JAX, PyTorch, and CuPy. This enables hardware acceleration (GPU/TPU), automatic differentiation through the chosen backend, and JIT compilation for 3D rigid-body transformations (rotations and translations). The redesign preserves the established SciPy interface and supports vectorized batch processing. Contribution/Results: The refactored module has been merged into SciPy’s main branch and will ship in the next release. Two case studies, a scalability study of 3D transforms and rotations and a JAX drone simulation that uses SciPy’s `Rotation` to integrate rotational dynamics, show how it supports spatial transformations in differentiable systems such as robotics, computer vision, and physics simulation without sacrificing usability or numerical robustness.
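The core idea behind Array API interoperability is that library code asks the input array for its namespace and then calls only standardized functions on that namespace, so the same code runs on NumPy, JAX, CuPy, and other compliant backends. The sketch below illustrates that pattern for rotating vectors by quaternions. `get_namespace` and `quat_apply` are hypothetical names chosen for illustration, not SciPy's internals; in practice, libraries typically rely on the `array_api_compat` package to also cover arrays (for example PyTorch tensors) that do not expose `__array_namespace__` natively.

```python
import numpy as np


def get_namespace(x):
    """Return the array-API namespace of ``x`` (hypothetical helper).

    Arrays that implement the standard expose ``__array_namespace__()``;
    anything else (e.g. Python lists, older NumPy) falls back to NumPy.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def quat_apply(q, v):
    """Rotate vectors ``v`` (..., 3) by quaternions ``q`` (..., 4), scalar-last.

    Hypothetical, illustrative function, not SciPy's implementation. Only
    array-API operations are used, so leading batch dimensions broadcast and
    the same code can run on any compliant backend.
    """
    xp = get_namespace(q)
    q = xp.asarray(q)
    v = xp.asarray(v)
    # Normalize so that non-unit inputs still describe a valid rotation.
    q = q / xp.sqrt(xp.sum(q * q, axis=-1, keepdims=True))
    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    # t = 2 * cross(u, v), where u = (x, y, z) is the vector part of q.
    tx = 2 * (y * vz - z * vy)
    ty = 2 * (z * vx - x * vz)
    tz = 2 * (x * vy - y * vx)
    # v' = v + w * t + cross(u, t)
    rx = vx + w * tx + (y * tz - z * ty)
    ry = vy + w * ty + (z * tx - x * tz)
    rz = vz + w * tz + (x * ty - y * tx)
    return xp.stack([rx, ry, rz], axis=-1)


q = np.array([[0.0, 0.0, 0.0, 1.0],   # identity
              [0.0, 0.0, 1.0, 1.0]])  # 90 degrees about z, not normalized
v = np.array([1.0, 0.0, 0.0])
print(quat_apply(q, v))  # one rotated vector per quaternion, shape (2, 3)
```

Assuming a recent JAX release in which `jax.Array` exposes `__array_namespace__`, calling `quat_apply(jnp.asarray(q), jnp.asarray(v))` would dispatch to `jax.numpy` and could then be wrapped in `jax.jit` or differentiated with `jax.grad`; that path is an untested assumption here.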
📝 Abstract
Three-dimensional rigid-body transforms, i.e., rotations and translations, are central to modern differentiable machine learning pipelines in robotics, vision, and simulation. However, implementing them in a numerically robust and mathematically correct way, particularly on SO(3), is error-prone because of issues such as axis conventions, normalization, composition consistency, and subtle errors that only appear in edge cases. SciPy's `spatial.transform` module is a rigorously tested Python implementation, but it historically supported only NumPy, limiting its adoption in GPU-accelerated and autodiff-based workflows. We present a complete overhaul of SciPy's `spatial.transform` functionality that makes it compatible with any array library implementing the Python array API, including JAX, PyTorch, and CuPy. The revised implementation preserves the established SciPy interface while enabling GPU/TPU execution, JIT compilation, vectorized batching, and differentiation via the chosen backend's native autodiff. We demonstrate how this foundation supports differentiable scientific computing through two case studies: (i) the scalability of 3D transforms and rotations and (ii) a JAX drone simulation that leverages SciPy's `Rotation` for accurate integration of rotational dynamics. Our contributions have been merged into SciPy main and will ship in the next release, providing a framework-agnostic, production-grade basis for 3D spatial math in differentiable systems and ML.
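The conventions the abstract warns about (axis order, normalization, composition order) are visible in SciPy's existing NumPy-backed `Rotation` interface, which the overhaul is described as preserving. The sketch below uses only long-standing public methods (`from_euler`, `from_quat`, `apply`, `as_quat`, `inv`, `magnitude`, and `*` for composition) on plain NumPy data. It does not exercise the new JAX, PyTorch, or CuPy backends, whose activation details are not covered here.

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

# A batch of three rotations about the z axis (0, 90 and 180 degrees).
# A single-axis sequence with a 1-D angle array yields one rotation per angle.
rz = R.from_euler("z", [0.0, 90.0, 180.0], degrees=True)

# Apply every rotation in the batch to the same vector; result has shape (3, 3).
v = np.array([1.0, 0.0, 0.0])
rotated = rz.apply(v)

# Quaternions use the scalar-last (x, y, z, w) order by default and are
# normalized on construction, so a non-unit input still gives a valid rotation.
q = R.from_quat([0.0, 0.0, 1.0, 1.0])  # 90 degrees about z, not normalized
unit_q = q.as_quat()

# Composition order: (a * b) applies b first, then a.
a = R.from_euler("x", 90.0, degrees=True)
b = R.from_euler("z", 90.0, degrees=True)
composed = (a * b).apply(v)
sequential = a.apply(b.apply(v))

# Composing each rotation with its inverse yields the identity (angle ~ 0 rad).
identity_error = (rz * rz.inv()).magnitude()

print(np.round(rotated, 6))
print(np.round(unit_q, 6))
print(np.allclose(composed, sequential))
print(np.max(identity_error))
```

Reading `a * b` as "`a` first" or passing scalar-first quaternions without saying so are silent errors that produce plausible but wrong results, which is why relying on a shared, tested implementation matters.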