Learning Bug Context for PyTorch-to-JAX Translation with LLMs

📅 2025-10-10
🤖 AI Summary
Addressing core challenges in PyTorch-to-JAX code translation (large semantic differences between the frameworks, scarce parallel corpora, and weak evaluation methodology), this paper proposes T2J, a prompt-augmentation framework for LLM-based translation. T2J builds a curated fixed-bug dataset through human-in-the-loop iterative repair of LLM-generated JAX drafts and uses it to construct structured, augmented prompts. It introduces three metrics tailored to PyTorch-to-JAX translation: the T2J CodeTrans Score (measuring functional correctness and syntactic fidelity on a 0–4 scale), the T2J FixCost Score (an LLM-based estimate of the manual effort needed to fix a translation), and the T2J Comparison Score (an LLM-as-judge comparison). Experiments show that T2J improves GPT-4o-mini by up to 10% on CodeBLEU, 50% on the FixCost Score, 1.33 points on the CodeTrans Score, and 100% on the Comparison Score, and that the generated code runs up to 2.5× faster than the baseline.
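Runtime speedups such as the 2.5× figure depend on how JAX code is timed: the first call of a jitted function includes compilation, and JAX dispatches work asynchronously, so a timer has to wait for the result with `block_until_ready()`. The sketch below times two translations that way, assuming both are pure functions of their array inputs. `time_jax_fn`, `speedup`, and the two softmax variants are hypothetical illustrations, not T2J's benchmarking code.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_jax_fn(fn, *args, warmup=2, repeats=10):
    """Median seconds per call of jax.jit(fn), excluding compilation time."""
    jitted = jax.jit(fn)
    for _ in range(warmup):
        jitted(*args).block_until_ready()  # first call compiles; not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        jitted(*args).block_until_ready()  # wait for the async computation
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline_s, candidate_s):
    """Speedup factor: 2.5 means the candidate runs 2.5x faster."""
    return baseline_s / candidate_s


# Hypothetical pair of functionally equivalent translations.
def softmax_rows_loop(x):
    out = []
    for row in x:  # Python loop, unrolled row by row at trace time
        e = jnp.exp(row - row.max())
        out.append(e / e.sum())
    return jnp.stack(out)


def softmax_vectorized(x):
    return jax.nn.softmax(x, axis=-1)


x = jax.random.normal(jax.random.PRNGKey(0), (64, 128))
print(speedup(time_jax_fn(softmax_rows_loop, x), time_jax_fn(softmax_vectorized, x)))
```

Taking the median of several runs makes the estimate less sensitive to one-off scheduling noise than a single measurement.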

📝 Abstract
Despite recent progress by large language models (LLMs) in code translation between mainstream languages, translating PyTorch to JAX remains nontrivial. The two libraries, though both embedded in Python, differ in core design, execution semantics, and ecosystem maturity; JAX is newer and comparatively underrepresented in public code, and parallel PyTorch–JAX corpora are limited. Weaknesses in existing evaluation further complicate cross-framework benchmarking. We present T2J, a prompt-augmentation framework that strengthens LLM-based PyTorch-to-JAX translation. Our pipeline (i) assembles two PyTorch sources, the problem-solving set from TorchLeet (Aroori & Chien, 2025) and a GitHub-derived set from CodeParrot (Wolf et al., 2022), and uses GPT-4o-mini to produce initial JAX drafts; (ii) engages two professional developers to iteratively repair those drafts until they are functionally equivalent, yielding a curated fixed-bug dataset of common errors and patches; and (iii) constructs augmented prompts that inject structured guidance from these fixes to steer lightweight LLMs (e.g., GPT-4o-mini). We also introduce three metrics tailored to PyTorch-to-JAX translation: T2J CodeTrans Score, T2J FixCost Score (an LLM-based estimate of bug-fix effort), and T2J Comparison Score (LLM-as-judge). Empirically, T2J raises GPT-4o-mini performance by up to 10% on CodeBLEU, 50% on T2J FixCost Score, 1.33 points on T2J CodeTrans Score (0–4 scale), and 100% on T2J Comparison Score; moreover, the generated code runs up to 2.5× faster than the baseline.
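Step (iii) can be pictured as ordinary prompt construction: each repaired error becomes a short, structured rule that is prepended to the translation request. The sketch below assumes a simple record of symptom, PyTorch idiom, and JAX fix; `FixRecord`, `KNOWN_FIXES`, `build_augmented_prompt`, and the prompt wording are all hypothetical, since the abstract does not specify T2J's actual prompt template.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FixRecord:
    """One entry of a hypothetical fixed-bug dataset; field names are illustrative."""
    error: str    # symptom observed in the first LLM-generated JAX draft
    pytorch: str  # PyTorch idiom that caused it
    jax_fix: str  # developer-verified JAX replacement


# Hypothetical examples of the kind of guidance such a dataset could hold.
KNOWN_FIXES = [
    FixRecord("in-place update of an immutable JAX array",
              "x[i] += 1", "x = x.at[i].add(1)"),
    FixRecord("implicit global random state",
              "torch.randn(3)", "jax.random.normal(key, (3,))"),
]


def build_augmented_prompt(torch_code: str, fixes: list[FixRecord]) -> str:
    """Prefix a translation request with structured guidance drawn from past fixes."""
    lines = [
        "Translate the following PyTorch code to JAX.",
        "Avoid these known pitfalls:",
    ]
    for i, fix in enumerate(fixes, start=1):
        lines.append(f"{i}. {fix.error}: write `{fix.jax_fix}` instead of `{fix.pytorch}`.")
    lines += ["", "--- PyTorch code ---", torch_code.rstrip(), "--- end ---"]
    return "\n".join(lines)


print(build_augmented_prompt("noise = torch.randn(3)", KNOWN_FIXES))
```

Keeping the rules as data rather than free text makes it easy to select only the fixes relevant to a given input, for example by matching the PyTorch idioms it contains.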
Problem

Translating PyTorch to JAX remains challenging due to differences in core design and execution semantics
Parallel PyTorch-JAX corpora are scarce, and existing evaluation benchmarks are weak
LLM-generated translations contain common errors that require significant manual repair
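To make the design differences concrete, the snippet below shows three PyTorch habits that a literal translation carries into JAX, each followed by the idiomatic JAX form. These are standard JAX behaviors; whether they match specific error categories in T2J's fixed-bug dataset is an assumption here, and `linear` and `params` are illustrative names.

```python
import jax
import jax.numpy as jnp

# 1. Immutability. PyTorch code often writes `x[0] = 1.0` in place. JAX arrays
#    are immutable, so a literal translation raises TypeError; the functional
#    .at[] API returns an updated copy instead.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    mutation_failed = False
except TypeError:
    mutation_failed = True
y = x.at[0].set(1.0)  # y == [1., 0., 0.]; x is still all zeros

# 2. Randomness. torch.randn draws from a hidden global generator. JAX needs an
#    explicit PRNG key, and the same key always yields the same numbers, so keys
#    must be split rather than reused.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))
same_noise = jax.random.normal(subkey, (3,))  # identical to `noise`

# 3. State. An nn.Module stores its weights as attributes; in plain JAX the
#    parameters are an explicit argument, which keeps the function pure, as
#    jax.jit expects.
def linear(params, inputs):
    return inputs @ params["w"] + params["b"]

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
out = jax.jit(linear)(params, y)  # [1., 1.]
```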
Innovation

Prompt-augmentation framework for PyTorch-to-JAX translation
Curated fixed-bug dataset from iterative developer repairs
Three specialized metrics for cross-framework code evaluation
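The abstract names the metrics and the 0–4 CodeTrans scale but not their formulas. As a hedged sketch of how LLM-as-judge outputs might be aggregated, assuming the judge returns one integer rating per sample for CodeTrans and one preferred system (or a tie) per pair for Comparison, the two hypothetical helpers below compute a mean rating and a win rate. Neither is T2J's definition, and the ratings in the usage lines are made up for illustration.

```python
def codetrans_score(judge_scores: list[int]) -> float:
    """Average per-sample LLM-judge ratings on the 0-4 CodeTrans scale.

    Hypothetical aggregation: the abstract gives the scale, not the formula.
    """
    if not judge_scores:
        raise ValueError("need at least one rating")
    for s in judge_scores:
        if not 0 <= s <= 4:
            raise ValueError(f"rating {s} is outside the 0-4 scale")
    return sum(judge_scores) / len(judge_scores)


def comparison_win_rate(verdicts: list[str], system: str = "t2j") -> float:
    """Share of pairwise LLM-as-judge verdicts that prefer `system`.

    Hypothetical: each verdict names the preferred system or "tie"; a tie
    counts as half a win.
    """
    if not verdicts:
        raise ValueError("need at least one verdict")
    credit = {system: 1.0, "tie": 0.5}
    return sum(credit.get(v, 0.0) for v in verdicts) / len(verdicts)


baseline = codetrans_score([1, 2, 2, 3])  # 2.0
t2j = codetrans_score([2, 3, 3, 4])       # 3.0
print(f"CodeTrans gain: {t2j - baseline:+.2f} points")  # +1.00 points
print(comparison_win_rate(["t2j", "tie", "baseline", "t2j"]))  # 0.625
```

Validating the scale inside the helper catches a judge reply that was parsed incorrectly before it silently skews the average.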