SoftJAX & SoftTorch: Empowering Automatic Differentiation Libraries with Informative Gradients

📅 2026-03-09
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the problem posed by non-differentiable “hard” operations, such as thresholding, Boolean logic, discrete indexing, and sorting, in mainstream automatic differentiation frameworks: these operations yield zero or undefined gradients and thereby impede end-to-end optimization. To overcome this limitation, the authors introduce SoftJAX and SoftTorch, two open-source, plug-and-play libraries that provide differentiable surrogates for JAX and PyTorch, respectively, drawing on fuzzy logic, optimal transport, permutahedron projections, and straight-through gradient estimators. The libraries collect a broad range of smoothing relaxations for elementwise, Boolean/indexing, and axiswise operations behind interfaces that mirror the original APIs, so the soft functions can be used as drop-in replacements. Benchmarks and a practical case study show that the libraries propagate informative gradients through otherwise non-differentiable components, enabling their use in end-to-end training pipelines. The code is publicly released.
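To make the idea of a soft relaxation concrete, here is a minimal, framework-free Python sketch. It is an illustration, not the SoftJAX or SoftTorch API: the function names and the temperature parameter `tau` are hypothetical. It assumes a sigmoid relaxation of the hard step, builds a soft `abs` from the resulting soft sign, and mimics a straight-through estimator by pairing the exact hard forward value with the surrogate's gradient. Derivatives are written out by hand (and checked with finite differences) rather than obtained from an AD framework.

```python
import math

# Illustrative sketch only: the function names and the temperature `tau`
# are hypothetical and are NOT the SoftJAX/SoftTorch API.

def sigmoid(z: float) -> float:
    """Numerically stable logistic function (avoids overflow in math.exp)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def hard_step(x: float) -> float:
    """Hard threshold: 1 if x > 0 else 0. Its derivative is 0 wherever it exists."""
    return 1.0 if x > 0 else 0.0

def soft_step(x: float, tau: float = 0.1) -> float:
    """Sigmoid relaxation of hard_step; approaches it as tau -> 0 (for x != 0)."""
    return sigmoid(x / tau)

def soft_step_grad(x: float, tau: float = 0.1) -> float:
    """Hand-derived derivative of soft_step: s * (1 - s) / tau."""
    s = soft_step(x, tau)
    return s * (1.0 - s) / tau

def soft_abs(x: float, tau: float = 0.1) -> float:
    """|x| = x * sign(x), with sign(x) relaxed to 2 * soft_step(x) - 1."""
    return x * (2.0 * soft_step(x, tau) - 1.0)

def straight_through_step(x: float, tau: float = 0.1) -> tuple[float, float]:
    """Straight-through estimator: exact hard value for the forward pass,
    surrogate gradient of soft_step for the backward pass."""
    return hard_step(x), soft_step_grad(x, tau)

def central_diff(f, x: float, h: float = 1e-5) -> float:
    """Finite-difference derivative, used to compare hard vs. soft gradients."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

if __name__ == "__main__":
    x = 0.05
    print("hard step:", hard_step(x), "grad:", central_diff(hard_step, x))
    print("soft step:", soft_step(x), "grad:", soft_step_grad(x))
    print("straight-through (value, grad):", straight_through_step(x))
```

Running the script shows the point of the relaxation: the finite-difference gradient of `hard_step` is exactly zero away from the threshold, whereas `soft_step` and the straight-through pair report a non-zero slope. A smaller `tau` brings the soft output closer to the hard one but concentrates the gradient in a narrower band around zero, which is the usual trade-off between fidelity and gradient quality.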

📝 Abstract
Automatic differentiation (AD) frameworks such as JAX and PyTorch have enabled gradient-based optimization across a wide range of scientific fields. Yet many "hard" primitives in these libraries, such as thresholding, Boolean logic, discrete indexing, and sorting operations, yield zero or undefined gradients that are not useful for optimization. While numerous "soft" relaxations that provide informative gradients have been proposed, their implementations are fragmented across projects, making them difficult to combine and compare. This work introduces SoftJAX and SoftTorch, open-source, feature-complete libraries for soft differentiable programming. The libraries provide a variety of soft functions as drop-in replacements for their hard JAX and PyTorch counterparts, including (i) elementwise operators such as clip or abs, (ii) utility methods for manipulating Booleans and indices via fuzzy logic, (iii) axiswise operators such as sort or rank, based on optimal transport or permutahedron projections, and (iv) full support for straight-through gradient estimation. Overall, SoftJAX and SoftTorch make the toolbox of soft relaxations easily accessible for differentiable programming, as demonstrated through benchmarking and a practical case study. Code is available at github.com/a-paulus/softjax and github.com/a-paulus/softtorch.
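The fuzzy-logic and axiswise items in the abstract can be illustrated with a short, framework-free Python sketch. The names and the temperature `tau` are hypothetical, and this is not the SoftJAX or SoftTorch API. It assumes the product t-norm for fuzzy AND/OR/NOT and, in place of the optimal-transport and permutahedron-projection methods the abstract names, uses a simpler pairwise-sigmoid soft rank chosen only because it fits in a few lines.

```python
import math

# Illustrative sketch only: names and the temperature `tau` are hypothetical.
# The soft rank below is a simple pairwise-sigmoid relaxation, NOT the
# optimal-transport or permutahedron-projection methods used by the libraries.

def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

# Fuzzy logic on truth values in [0, 1]: negation, product t-norm, and its dual.
# On crisp inputs {0, 1} these reduce to ordinary Boolean NOT/AND/OR.
def fuzzy_not(a: float) -> float:
    return 1.0 - a

def fuzzy_and(a: float, b: float) -> float:
    return a * b

def fuzzy_or(a: float, b: float) -> float:
    return a + b - a * b

def soft_greater(x: float, y: float, tau: float = 0.1) -> float:
    """Soft Boolean for x > y, returned as a truth value in (0, 1)."""
    return sigmoid((x - y) / tau)

def hard_rank(xs: list[float]) -> list[int]:
    """1-based ascending ranks (ties broken by position); piecewise constant in xs."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    ranks = [0] * len(xs)
    for r, i in enumerate(order, start=1):
        ranks[i] = r
    return ranks

def soft_rank(xs: list[float], tau: float = 0.1) -> list[float]:
    """rank_i ~= 1 + sum_{j != i} soft_greater(x_i, x_j): smooth in xs, O(n^2) pairs."""
    return [
        1.0 + sum(soft_greater(xi, xj, tau) for j, xj in enumerate(xs) if j != i)
        for i, xi in enumerate(xs)
    ]

if __name__ == "__main__":
    xs = [0.3, -1.2, 2.0]
    print("hard ranks:", hard_rank(xs))
    print("soft ranks:", soft_rank(xs, tau=0.5))
    # Soft Boolean "0 < x < 1" built from two soft comparisons and fuzzy AND.
    x = 0.4
    print("soft 0 < x < 1:", fuzzy_and(soft_greater(x, 0.0), soft_greater(1.0, x)))
```

As `tau` shrinks, `soft_rank` approaches `hard_rank` for distinct inputs, while each output remains a smooth function of `xs`, so an AD framework would return non-zero gradients for it. Because `soft_greater(a, b) + soft_greater(b, a) = 1`, the soft ranks always sum to n(n+1)/2, just like the hard ranks. The pairwise construction costs O(n²) comparisons.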
Problem

Research questions and friction points this paper is trying to address.

automatic differentiation
hard primitives
informative gradients
soft relaxations
differentiable programming
Innovation

Methods, ideas, or system contributions that make the work stand out.

soft differentiable programming
informative gradients
automatic differentiation
gradient relaxation
straight-through estimator