🤖 AI Summary
This work addresses defect detection in deep learning libraries such as TensorFlow and PyTorch, whose complex design makes them prone to bugs. Existing bug-finding techniques raise many false alarms because they lack precise API specifications, and existing methods for mining those specifications lack accuracy. The authors instead train machine learning classifiers to decide whether an input is valid for a given API, using tensor shapes as a precise, low-dimensional abstraction of concrete inputs. Labeled training data are generated automatically by running sampled inputs and observing the runtime outcome, so no manual annotation is needed. Evaluated on 183 APIs from TensorFlow and PyTorch, the classifiers reach over 91% accuracy on unseen data, and integrating them into the pipeline of ACETest, a state-of-the-art bug-finding technique, raises its pass rate from about 29% to about 61%.
📝 Abstract
Deep Learning (DL) libraries like TensorFlow and PyTorch simplify machine learning (ML) model development but are prone to bugs due to their complex design. Bug-finding techniques exist, but without precise API specifications, they produce many false alarms. Existing methods to mine API specifications lack accuracy. We explore using ML classifiers to determine input validity. We hypothesize that tensor shapes are a precise abstraction to encode concrete inputs and capture relationships of the data. Shape abstraction drastically reduces problem dimensionality, which is important to facilitate ML training. Labeled data are obtained by observing runtime outcomes on a sample of inputs, and classifiers are trained on sets of labeled inputs to capture API constraints. Our evaluation, conducted over 183 APIs from TensorFlow and PyTorch, shows that the classifiers generalize well on unseen data with over 91% accuracy. Integrating these classifiers into the pipeline of ACETest, a state-of-the-art (SoTA) bug-finding technique, improves its pass rate from ~29% to ~61%. Our findings suggest that ML-enhanced input classification is an important aid to scale DL library testing.
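The workflow described in the abstract (abstract inputs to shapes, label them by runtime outcome, train a classifier, then filter generated inputs) can be illustrated with a minimal, standard-library-only sketch. Everything in it is an assumption made for illustration and is not taken from the paper: `toy_matmul` is a hypothetical stand-in for a real library API, the pairwise-equality flags are one possible way to expose relationships between dimensions, and the one-split decision stump replaces whatever classifier the authors actually use.

```python
import random


def toy_matmul(a_shape, b_shape):
    """Hypothetical stand-in for a DL library API: 2-D matrix multiply on shapes."""
    if a_shape[1] != b_shape[0]:
        raise ValueError("inner dimensions do not match")
    return (a_shape[0], b_shape[1])


def abstract(a_shape, b_shape):
    """Shape abstraction (assumed encoding): dimensions plus pairwise equality flags."""
    dims = list(a_shape) + list(b_shape)
    eq = [int(dims[i] == dims[j])
          for i in range(len(dims)) for j in range(i + 1, len(dims))]
    return dims + eq


def label(a_shape, b_shape):
    """Runtime feedback: 1 if the call succeeds, 0 if it raises."""
    try:
        toy_matmul(a_shape, b_shape)
        return 1
    except ValueError:
        return 0


def sample_shapes(rng, n):
    """Draw n random pairs of 2-D shapes with dimensions in 1..4."""
    return [((rng.randint(1, 4), rng.randint(1, 4)),
             (rng.randint(1, 4), rng.randint(1, 4))) for _ in range(n)]


def train_stump(data):
    """Pick the single (feature, threshold, polarity) split with the best training accuracy."""
    best = None
    for f in range(len(data[0][0])):
        for t in sorted({x[f] for x, _ in data}):
            for polarity in (0, 1):
                correct = sum(1 for x, y in data
                              if (int(x[f] >= t) ^ polarity) == y)
                if best is None or correct > best[0]:
                    best = (correct, f, t, polarity)
    _, f, t, polarity = best
    return lambda x: int(x[f] >= t) ^ polarity


rng = random.Random(0)
train = [(abstract(a, b), label(a, b)) for a, b in sample_shapes(rng, 400)]
test = [(abstract(a, b), label(a, b)) for a, b in sample_shapes(rng, 200)]

predict = train_stump(train)
accuracy = sum(predict(x) == y for x, y in test) / len(test)

# Use the classifier as a pre-filter: only run inputs it predicts to be valid.
candidates = sample_shapes(rng, 200)
kept = [(a, b) for a, b in candidates if predict(abstract(a, b))]
pass_rate_before = sum(label(a, b) for a, b in candidates) / len(candidates)
pass_rate_after = sum(label(a, b) for a, b in kept) / len(kept)

print(f"held-out accuracy: {accuracy:.2f}")
print(f"pass rate without filter: {pass_rate_before:.2f}, with filter: {pass_rate_after:.2f}")
```

The constraint in this toy case is a single equality between two dimensions, so a one-split classifier over the abstracted features is enough. Real API constraints are richer, and the sketch only shows how the pieces fit together: labels come from execution rather than manual annotation, and the trained classifier raises the share of generated inputs that pass the API's validity checks.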