JetSCI: A Hybrid JAX-PETSc Framework for Scalable Differentiable Simulation

📅 2026-04-23

🤖 AI Summary
This work addresses the limited scalability and robustness of pure-JAX frameworks for large-scale differentiable scientific computing by introducing JetSCI, a framework that tightly integrates JAX with PETSc. JetSCI uses JAX for GPU-accelerated automatic differentiation and differentiable discretization, and relies on PETSc's high-performance solvers to solve the resulting large-scale linear and nonlinear systems on distributed-memory architectures. The framework supports multilevel parallelism, combining GPU parallelism within a node with MPI-based parallelism across nodes. On heterogeneous micromechanical finite element simulations, JetSCI improves on pure-JAX implementations in both computational efficiency and numerical accuracy. By building on a mature high-performance computing (HPC) solver stack while preserving machine-learning-friendly differentiability, JetSCI improves the scalability and stability of large-scale differentiable simulations.
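The division of labor described above (JAX for the differentiable discretization, an external sparse solver for the resulting system) can be sketched on a toy problem. This is a minimal sketch under stated assumptions, not JetSCI's API: it uses a hypothetical 1D bar with per-element moduli `E` and linear two-node elements. JAX autodiff and `vmap` turn an element energy into element residuals and stiffness matrices, and SciPy's sparse direct solver stands in for the PETSc linear solve (in petsc4py that role would be played by a `KSP` object). The names `element_energy`, `element_residual`, `element_stiffness`, and `solve_bar` are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

jax.config.update("jax_enable_x64", True)  # double precision for the solver stage


def element_energy(u_e, E_e, h, f):
    """Hypothetical 2-node bar element: strain energy minus work of a uniform load f."""
    strain = (u_e[1] - u_e[0]) / h
    return 0.5 * E_e * strain**2 * h - 0.5 * f * h * (u_e[0] + u_e[1])


# JAX side: residual (gradient) and tangent stiffness (Hessian) by autodiff,
# vectorized over all elements with vmap.
element_residual = jax.vmap(jax.grad(element_energy), in_axes=(0, 0, None, None))
element_stiffness = jax.vmap(jax.hessian(element_energy), in_axes=(0, 0, None, None))


def solve_bar(E, f=1.0):
    """One Newton step from u = 0 on [0, 1] with u(0) = u(1) = 0; E holds per-element moduli."""
    n = E.shape[0]
    h = 1.0 / n
    conn = np.stack([np.arange(n), np.arange(1, n + 1)], axis=1)  # element -> node ids
    u = np.zeros(n + 1)

    u_e = jnp.asarray(u[conn])
    Re = np.asarray(element_residual(u_e, jnp.asarray(E), h, f))   # (n, 2)
    Ke = np.asarray(element_stiffness(u_e, jnp.asarray(E), h, f))  # (n, 2, 2)

    # Assembly: scatter-add element contributions into global arrays.
    R = np.zeros(n + 1)
    np.add.at(R, conn, Re)
    rows = np.broadcast_to(conn[:, :, None], Ke.shape).ravel()
    cols = np.broadcast_to(conn[:, None, :], Ke.shape).ravel()
    K = sp.coo_matrix((Ke.ravel(), (rows, cols)), shape=(n + 1, n + 1)).tocsr()

    # Solver side (stand-in for a PETSc KSP solve): restrict to free dofs, solve K du = -R.
    free = np.arange(1, n)
    du = spla.spsolve(K[free][:, free].tocsc(), -R[free])
    u[free] += du
    return u
```

Because this toy energy is quadratic, a single Newton step is exact. For a homogeneous bar (`E = np.ones(8)`, `f = 1`), the nodal values match the exact solution x(1 - x)/2, since linear elements are nodally exact for this 1D problem. Passing a two-phase `E` gives a simple heterogeneous case.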

📝 Abstract
The rapid rise of scientific machine learning (SciML) has expanded the role of differentiable modeling, surrogate modeling, and data-driven constitutive laws in large-scale simulation. The JAX framework provides an attractive environment for these workflows through automatically differentiable programs, vectorization, and GPU acceleration, while also enabling seamless learning of surrogate models. However, large-scale simulation still relies on mature HPC infrastructure. Libraries such as PETSc provide scalable MPI-based parallelism, robust linear and nonlinear solvers, and advanced preconditioning capabilities that remain difficult to reproduce in JAX-only workflows. We present JetSCI, a hybrid JAX-PETSc framework that unifies these complementary strengths. JetSCI uses JAX for GPU-parallel differentiable discretizations and PETSc for robust, scalable solution of the resulting systems on distributed-memory architectures, exposing multilevel parallelism through GPU acceleration within nodes and MPI parallelism across nodes. For finite element discretizations of heterogeneous micromechanics problems, JetSCI outperforms JAX-only implementations in both efficiency and accuracy.
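The two levels of parallelism can be illustrated with a minimal single-process sketch. It rests on three assumptions: a hypothetical nonlinear 1D bar energy, a contiguous partition of elements standing in for MPI ranks, and a Python loop standing in for inter-node communication. Within each simulated rank, `jax.vmap` evaluates all owned elements in one vectorized kernel (the intra-GPU level). Contributions at nodes shared between ranks are then summed, which is the role an MPI reduction or PETSc vector assembly with `ADD_VALUES` would play across nodes. This is not JetSCI's implementation; `assemble_residual` and `n_ranks` are illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def element_energy(u_e, E_e, h):
    """Hypothetical nonlinear (quadratic + quartic) bar energy, used only to exercise autodiff."""
    s = (u_e[1] - u_e[0]) / h
    return h * E_e * (0.5 * s**2 + 0.25 * s**4)


# Intra-device level: one jitted, vectorized kernel over all elements a "rank" owns.
local_residual = jax.jit(jax.vmap(jax.grad(element_energy), in_axes=(0, 0, None)))


def assemble_residual(u, E, n_ranks):
    """Global residual on [0, 1], with elements split across n_ranks simulated ranks."""
    n = E.shape[0]
    h = 1.0 / n
    conn = np.stack([np.arange(n), np.arange(1, n + 1)], axis=1)
    R = np.zeros(n + 1)
    # Inter-node level (simulated): each "rank" owns a contiguous block of elements.
    for owned in np.array_split(np.arange(n), n_ranks):
        Re = np.asarray(local_residual(jnp.asarray(u[conn[owned]]), jnp.asarray(E[owned]), h))
        # Summing contributions at shared nodes plays the role of an MPI reduction
        # (e.g. PETSc vector assembly with ADD_VALUES).
        np.add.at(R, conn[owned], Re)
    return R
```

Every partition yields the same global residual as the single-partition run, so the split across ranks changes only where the work happens, not the result. A real distributed code would also need ghost-node exchange and a distributed solver, which this sketch omits.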
Problem

Research questions and friction points this paper is trying to address.

differentiable simulation
scientific machine learning
JAX
PETSc
scalable HPC
Innovation

Methods, ideas, or system contributions that make the work stand out.

differentiable simulation
JAX-PETSc hybrid framework
multilevel parallelism
scientific machine learning
scalable HPC