🤖 AI Summary
This work addresses the limited scalability and robustness of JAX-only frameworks for large-scale differentiable scientific computing by introducing JetSCI, a hybrid framework that couples JAX with PETSc. JetSCI uses JAX for GPU-accelerated automatic differentiation and differentiable discretization, and relies on PETSc's solvers and preconditioners to solve the resulting large linear and nonlinear systems on distributed-memory architectures. The framework exposes multilevel parallelism, combining GPU acceleration within each node with MPI parallelism across nodes. On finite element simulations of heterogeneous micromechanics problems, JetSCI outperforms JAX-only implementations in both computational efficiency and numerical accuracy. By building on a mature high-performance computing (HPC) solver stack while preserving the differentiability that machine learning workflows need, JetSCI improves the scalability and robustness of large-scale differentiable simulations.
📝 Abstract
The rapid rise of scientific machine learning (SciML) has expanded the role of differentiable modeling, surrogate modeling, and data-driven constitutive laws in large-scale simulation. The JAX framework provides an attractive environment for these workflows through automatically differentiable programs, vectorization, and GPU acceleration, while enabling seamless learning of surrogate models. However, large-scale simulation still relies on mature HPC infrastructure. Libraries such as PETSc provide scalable MPI-based parallelism, robust linear and nonlinear solvers, and advanced preconditioning capabilities that remain difficult to reproduce in JAX-only workflows. We present JetSCI, a hybrid JAX-PETSc framework that unifies these complementary strengths. JetSCI uses JAX for GPU-parallel differentiable discretizations and PETSc for robust, scalable solution of the resulting systems on distributed-memory architectures, exposing multilevel parallelism through GPU acceleration within nodes and MPI parallelism across nodes. For finite element discretizations of heterogeneous micromechanics problems, JetSCI outperforms JAX-only implementations in efficiency and accuracy.
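The division of labor described in the abstract can be illustrated with a small sketch. This is not JetSCI's API, which the abstract does not show. It is a toy 1D bar with heterogeneous stiffness, where JAX derives the residual and tangent matrix from an energy functional by automatic differentiation and a separate solver handles each linear system. The function names (`total_energy`, `newton_solve`), the quartic energy density, and the dense `np.linalg.solve` call are all assumptions made for illustration; the dense solve only marks the place where a hybrid framework would hand the system to a distributed solver such as PETSc.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # double precision for the Newton iteration


def total_energy(u_free, stiffness, load):
    """Potential energy of a 1D bar clamped at x=0 with a point load at x=1.

    Hypothetical toy model: one stiffness value per element (heterogeneous),
    with a mildly nonlinear energy density 0.5*E*eps^2 + 0.25*E*eps^4.
    """
    u = jnp.concatenate([jnp.zeros(1), u_free])  # prepend the clamped node
    h = 1.0 / stiffness.shape[0]                 # uniform element length
    strain = jnp.diff(u) / h
    density = 0.5 * stiffness * strain**2 + 0.25 * stiffness * strain**4
    return h * jnp.sum(density) - load * u[-1]


# Differentiable discretization: residual and tangent come from autodiff.
residual = jax.jit(jax.grad(total_energy))
jacobian = jax.jit(jax.jacfwd(jax.grad(total_energy)))


def newton_solve(stiffness, load, tol=1e-10, max_iter=50):
    """Newton iteration; the linear solve is a stand-in for an external solver."""
    u = jnp.zeros(stiffness.shape[0])
    for _ in range(max_iter):
        r = residual(u, stiffness, load)
        if float(jnp.linalg.norm(r)) < tol:
            break
        K = jacobian(u, stiffness, load)
        # Assumption: a dense NumPy solve replaces the distributed,
        # preconditioned solve that a library like PETSc would perform.
        du = np.linalg.solve(np.asarray(K), -np.asarray(r))
        u = u + jnp.asarray(du)
    return u


stiffness = jnp.array([1.0, 10.0, 1.0, 10.0])  # alternating soft/stiff elements
u = newton_solve(stiffness, load=0.1)
```

At equilibrium the stress `E * (eps + eps**3)` is the same in every element and equals the applied load, which gives a simple check on the result. In a production setting the dense Jacobian would be replaced by a sparse, distributed matrix, and the Newton loop itself could be delegated to a nonlinear solver.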