ShardTensor: Domain Parallelism for Scientific Machine Learning

📅 2026-05-11

🤖 AI Summary
This work addresses poor scalability and accuracy degradation in scientific machine learning on extremely high-resolution data. These problems stem largely from the absence of a general-purpose parallelization framework that supports batch sizes below one sample per device. The authors propose ShardTensor, a domain-parallel paradigm that shards tensors along their spatial dimensions, decoupling data size from per-device hardware limits. The approach enables, for the first time, general-purpose parallel training and inference with less than one sample per device. By supporting multidimensional parallelism and integrating dynamic computation–communication load balancing, ShardTensor achieves both strong scaling, which reduces latency, and weak scaling, which allows larger inputs to be processed. Together these improve the scalability and efficiency of high-fidelity scientific computing tasks.
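The core mechanism here is sharding a single sample along its spatial axis so that each device holds less than one full sample. A generic domain-decomposition sketch can illustrate it. This is not ShardTensor's API, which this page does not describe. All function names below are hypothetical, the "devices" are simulated as plain Python lists, and the halo exchange that a real system would perform with inter-GPU communication is modeled here by list slicing. The final check confirms that the sharded stencil reproduces the unsharded result.

```python
from typing import List

def conv1d_same(x: List[float], w: List[float]) -> List[float]:
    """Unsharded reference: zero-padded, same-size 1D stencil (cross-correlation)."""
    h = len(w) // 2
    padded = [0.0] * h + list(x) + [0.0] * h
    return [sum(w[j] * padded[i + j] for j in range(len(w))) for i in range(len(x))]

def split_domain(x: List[float], n: int) -> List[List[float]]:
    """Hypothetical helper: split one sample into n contiguous spatial shards."""
    size, rem = divmod(len(x), n)
    shards, start = [], 0
    for r in range(n):
        end = start + size + (1 if r < rem else 0)  # sizes differ by at most 1
        shards.append(x[start:end])
        start = end
    return shards

def halo_exchange(shards: List[List[float]], h: int) -> List[List[float]]:
    """Pad each shard with h values from each neighbor; physical edges get zeros.

    In a real multi-GPU setting this step would be point-to-point communication.
    """
    padded = []
    for r, s in enumerate(shards):
        left = shards[r - 1][-h:] if r > 0 else [0.0] * h
        right = shards[r + 1][:h] if r < len(shards) - 1 else [0.0] * h
        padded.append(left + s + right)
    return padded

def local_stencil(p: List[float], w: List[float]) -> List[float]:
    """'Valid' stencil on a halo-padded shard; output length equals the shard length."""
    k = len(w)
    return [sum(w[j] * p[i + j] for j in range(k)) for i in range(len(p) - k + 1)]

def sharded_conv1d(x: List[float], w: List[float], n_devices: int) -> List[float]:
    if len(w) % 2 != 1 or len(w) < 3:
        raise ValueError("this sketch assumes an odd kernel of width >= 3")
    h = len(w) // 2
    shards = split_domain(x, n_devices)
    if min(len(s) for s in shards) < h:
        raise ValueError("each shard must be at least as wide as the halo")
    # Each simulated device computes on its own halo-padded shard independently.
    outputs = [local_stencil(p, w) for p in halo_exchange(shards, h)]
    return [v for out in outputs for v in out]

x = [float(i * i % 7) for i in range(12)]
w = [0.25, 0.5, 0.25]
assert sharded_conv1d(x, w, 4) == conv1d_same(x, w)
```

The halo width grows with the kernel's receptive field. That is one reason communication cost becomes significant as shards shrink.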
📝 Abstract
Scientific Machine Learning (SciML) faces unique challenges for extreme-resolution data, with mitigations that often fail to scale or degrade the accuracy of trained models. While some specialized methods have achieved remarkable results in training models or performing inference on massive spatial datasets with bespoke techniques, there is no generalized framework for parallelization over input data below batch size one per device. In this work, we introduce ShardTensor: a novel paradigm of domain parallelism that enables flexible scaling of input data to arbitrary sizes. By decoupling the spatial dimensionality of input data from hardware constraints, ShardTensor enables scientific machine learning workloads to reach new levels of high-fidelity training and inference. We demonstrate both strong and weak scaling of workloads during training and inference, showing improved latency with strong scaling and the capacity to process larger data sizes with weak scaling. Additionally, we demonstrate multiple dimensions of parallelization, removing barriers to SciML on extreme-scale inputs.
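The abstract contrasts strong scaling, where the problem size is fixed and devices are added, with weak scaling, where the work per device is fixed. A simple counting model makes the difference concrete. This is a sketch under stated assumptions, not the paper's measurements. It assumes a 2D grid split into horizontal slabs (a 1D decomposition) with a halo of fixed width, and it uses the ratio of halo cells to locally owned cells as a rough proxy for communication relative to computation. The names `SlabShard` and `slab_shard` are hypothetical.

```python
from dataclasses import dataclass

@dataclass
class SlabShard:
    local_cells: int   # cells owned by one device
    halo_cells: int    # cells received from neighbors per step
    halo_ratio: float  # rough communication-to-computation proxy

def slab_shard(rows: int, cols: int, devices: int, halo: int) -> SlabShard:
    """Hypothetical model: split a rows x cols grid into `devices` horizontal slabs.

    Reports the worst-case shard, i.e. a slab with up to two neighbors.
    """
    if rows % devices != 0:
        raise ValueError("rows must divide evenly across devices")
    local_rows = rows // devices
    if devices > 1 and local_rows < halo:
        raise ValueError("each slab must be at least `halo` rows thick")
    neighbors = min(2, devices - 1)  # a single device has no neighbors
    local = local_rows * cols
    halo_cells = neighbors * halo * cols
    return SlabShard(local, halo_cells, halo_cells / local)

# Strong scaling: a fixed 1024 x 1024 field split over more devices.
strong = {p: slab_shard(1024, 1024, p, halo=1) for p in (1, 2, 4, 8)}
# Weak scaling: 256 rows per device, so the global field grows with p.
weak = {p: slab_shard(256 * p, 1024, p, halo=1) for p in (4, 8, 16)}

for p, s in strong.items():
    print(f"strong p={p}: local={s.local_cells}, halo ratio={s.halo_ratio:.4f}")
for p, s in weak.items():
    print(f"weak   p={p}: local={s.local_cells}, halo ratio={s.halo_ratio:.4f}")
```

Under strong scaling, each device's share of the work shrinks, which is the source of lower latency, while the halo ratio grows. Under weak scaling, both stay constant as the global input grows, which is what allows larger inputs to be processed.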
Problem

Research questions and friction points this paper is trying to address.

Scientific Machine Learning
extreme-resolution data
domain parallelism
scalability
input data parallelization
Innovation

Methods, ideas, or system contributions that make the work stand out.

ShardTensor
domain parallelism
scientific machine learning
extreme-resolution data
input sharding
Corey Adams
NVIDIA, Santa Clara, CA, USA
Peter Harrington
NVIDIA
Akshay Subramaniam
NVIDIA, Santa Clara, CA, USA
Mohammad Shoaib Abbas
NVIDIA, Santa Clara, CA, USA
Jaideep Pathak
NVIDIA
Mike Pritchard
NVIDIA, Santa Clara, CA, USA
Sanjay Choudhry
NVIDIA, Santa Clara, CA, USA