🤖 AI Summary
In PyTorch 2, FX graph breaks—caused primarily by dynamic control flow and unsupported Python I/O operations—fragment computational graphs, trigger frequent fallbacks to eager mode, exacerbate CPU-GPU synchronization overhead, and hinder graph-level optimizations. Method: We propose a source-to-source transformation technique in the compiler frontend that automatically analyzes and rewrites Python code prior to TorchDynamo’s graph capture. Leveraging the Jac framework, our approach integrates static analysis with targeted source rewriting to eliminate *repairable* graph breaks, enabling seamless interoperability between TorchDynamo and TorchInductor. Contribution/Results: Evaluated on eight Hugging Face models, our method achieves zero graph breaks for six models, reduces inference latency by up to 75%, and improves end-to-end throughput by up to 8%. This work is the first to systematically incorporate source-level transformations into the PyTorch JIT compilation pipeline, significantly expanding both the scope of compilable graphs and the applicability of downstream optimizations.
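To make the I/O case concrete: a `print` (or logging call) inside a model's `forward` forces TorchDynamo to split the FX graph at that statement. The sketch below, a minimal illustration and not GraphMend's actual Jac-based implementation, uses Python's `ast` module to detect and strip such a call before graph capture; the `forward` source string and its identifiers are hypothetical, and a real tool would relocate the I/O rather than simply delete it.

```python
import ast

# Hypothetical model code: the print() in the middle of forward()
# interrupts TorchDynamo's graph capture, splitting the FX graph.
SRC = """
def forward(self, x):
    h = self.linear(x)
    print('activation norm:', h.norm())
    return h.relu()
"""

class StripIO(ast.NodeTransformer):
    """Remove bare print() statements so graph capture is uninterrupted."""
    def visit_Expr(self, node):
        call = node.value
        if (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == "print"):
            return None  # drop the statement (a real tool would defer it)
        return node

tree = StripIO().visit(ast.parse(SRC))
ast.fix_missing_locations(tree)
print(ast.unparse(tree))  # forward() with the I/O call removed
```

Running this prints the rewritten `forward` with only the two tensor operations remaining, the shape of source that TorchDynamo can capture as a single FX graph.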
📝 Abstract
This paper presents GraphMend, a high-level compiler that eliminates FX graph breaks in PyTorch 2 programs. Although PyTorch 2 introduced TorchDynamo and TorchInductor to enable just-in-time graph compilation, unresolved dynamic control flow and unsupported Python constructs often fragment models into multiple FX graphs. These fragments force frequent fallbacks to eager mode, incur costly CPU-to-GPU synchronizations, and reduce optimization opportunities. GraphMend addresses this limitation by analyzing and transforming source code before execution. Built on the Jac compilation framework, GraphMend introduces two code transformations that remove graph breaks due to dynamic control flow and Python I/O functions. This design allows PyTorch's compilation pipeline to capture larger, uninterrupted FX graphs without requiring manual refactoring by developers. Evaluation across eight Hugging Face models shows that GraphMend removes all fixable graph breaks due to dynamic control flow and Python I/O functions, driving the break count to zero in six models and reducing it from 5 to 2 in another. On NVIDIA RTX 3090 and A40 GPUs, GraphMend achieves up to 75% latency reductions and up to 8% higher end-to-end throughput. These results demonstrate that high-level code transformation is an effective complement to PyTorch's dynamic JIT compilation pipeline, substantially improving both usability and performance.
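The dynamic-control-flow transformation can be illustrated in the same spirit. A data-dependent `if` on a tensor value is a classic graph-break site; one well-known repair is to rewrite a two-sided branch that assigns the same variable into a single `torch.where` call. The sketch below is an assumed, simplified stand-in for GraphMend's rewrite (the real system is built on Jac and handles more patterns); the `gate` function and its names are hypothetical, and note that `torch.where` evaluates both branches, so this rewrite is only valid when both are side-effect free.

```python
import ast

# Hypothetical source: the branch on x.sum() > threshold is
# data-dependent, so TorchDynamo breaks the graph here.
SRC = """
def gate(x, threshold):
    if x.sum() > threshold:
        y = x * 2
    else:
        y = x - 1
    return y
"""

class BranchToWhere(ast.NodeTransformer):
    """Rewrite `if c: y = A else: y = B` into `y = torch.where(c, A, B)`."""
    def visit_If(self, node):
        if (len(node.body) == 1 and len(node.orelse) == 1
                and isinstance(node.body[0], ast.Assign)
                and isinstance(node.orelse[0], ast.Assign)):
            t, f = node.body[0], node.orelse[0]
            if (len(t.targets) == 1 and len(f.targets) == 1
                    and isinstance(t.targets[0], ast.Name)
                    and isinstance(f.targets[0], ast.Name)
                    and t.targets[0].id == f.targets[0].id):
                where = ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id="torch", ctx=ast.Load()),
                        attr="where", ctx=ast.Load()),
                    args=[node.test, t.value, f.value], keywords=[])
                return ast.copy_location(
                    ast.Assign(targets=[t.targets[0]], value=where), node)
        return node  # pattern not matched: leave the branch alone

tree = BranchToWhere().visit(ast.parse(SRC))
ast.fix_missing_locations(tree)
print(ast.unparse(tree))  # branch-free source, capturable as one FX graph
```

After the rewrite, the function body is straight-line tensor code, so TorchDynamo can trace it into a single FX graph and hand it to TorchInductor without a break.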