JAXMg: A multi-GPU linear solver in JAX

📅 2026-01-20
🤖 AI Summary
This work addresses the challenge of scaling dense linear algebra, specifically linear solves and eigenvalue problems, to multiple GPUs while remaining compatible with just-in-time (JIT) compilation. The paper presents JAXMg, which integrates NVIDIA's cuSOLVERMg library into JAX through the XLA Foreign Function Interface and exposes its distributed solvers as composable, JIT-compatible primitives. This enables Cholesky-based linear solves and symmetric eigendecompositions for matrices that exceed the memory of a single GPU. Because the solvers are ordinary JAX primitives, they compose with JAX transformations and can be embedded directly in end-to-end scientific computing workflows.

📝 Abstract
Solving large dense linear systems and eigenvalue problems is a core requirement in many areas of scientific computing, but scaling these operations beyond a single GPU remains challenging within modern programming frameworks. While highly optimized multi-GPU solver libraries exist, they are typically difficult to integrate into composable, just-in-time (JIT) compiled Python workflows. JAXMg provides multi-GPU dense linear algebra for JAX, enabling Cholesky-based linear solves and symmetric eigendecompositions for matrices that exceed single-GPU memory limits. By interfacing JAX with NVIDIA's cuSOLVERMg through an XLA Foreign Function Interface, JAXMg exposes distributed GPU solvers as JIT-compatible JAX primitives. This design allows scalable linear algebra to be embedded directly within JAX programs, preserving composability with JAX transformations and enabling multi-GPU execution in end-to-end scientific workflows.
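As a point of reference, the two operations named in the abstract can be sketched with JAX's built-in single-device routines. This is a minimal illustration under stated assumptions, not JAXMg's API: it uses only `jax.scipy.linalg.cho_factor`, `jax.scipy.linalg.cho_solve`, and `jnp.linalg.eigh`, runs on one device, and does not call cuSOLVERMg. The helper names `cholesky_solve` and `make_spd` and the 64 x 64 test matrix are hypothetical and chosen only for the example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


@jax.jit
def cholesky_solve(a, b):
    """Solve a @ x = b for a symmetric positive-definite matrix a."""
    # Factor a = L @ L.T once, then solve by two triangular substitutions.
    factor = cho_factor(a, lower=True)
    return cho_solve(factor, b)


def make_spd(key, n):
    """Build a well-conditioned symmetric positive-definite test matrix."""
    m = jax.random.normal(key, (n, n))
    return m @ m.T + n * jnp.eye(n)


n = 64  # hypothetical size; small enough for any single device
a = make_spd(jax.random.PRNGKey(0), n)
b = jnp.ones((n,))

# JIT-compiled Cholesky-based linear solve.
x = cholesky_solve(a, b)

# Composability: differentiate a scalar loss through the jitted solve.
def loss(rhs):
    return jnp.sum(cholesky_solve(a, rhs) ** 2)

grad_b = jax.grad(loss)(b)

# Symmetric eigendecomposition: a @ v = v * w, with w in ascending order.
w, v = jnp.linalg.eigh(a)
```

The sketch shows the composability the abstract refers to: the solve is an ordinary JAX function, so `jax.jit` and `jax.grad` apply to it without special handling. According to the abstract, JAXMg keeps this programming model while distributing the factorization across GPUs for matrices too large for one device.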
Problem

Research questions and friction points this paper is trying to address.

multi-GPU
dense linear systems
eigenvalue problems
JIT compilation
scientific computing
Innovation

Methods, ideas, or system contributions that make the work stand out.

multi-GPU
JAX
dense linear algebra
JIT compilation
cuSOLVERMg