🤖 AI Summary
This study systematically evaluates the end-to-end performance of three major deep learning frameworks—Keras, PyTorch, and JAX—on medical image classification. Using the PathMNIST dataset and a unified lightweight CNN architecture, we conduct a standardized, reproducible comparison under identical hardware and software conditions, measuring training efficiency (time per epoch, GPU memory consumption), test-set classification accuracy, and inference latency. To our knowledge, this is the first such comprehensive, application-specific benchmark for medical imaging. Results reveal critical trade-offs across accuracy, speed, and resource utilization: PyTorch achieves the best balance of accuracy and framework flexibility; JAX delivers the fastest training throughput under large-batch regimes; and Keras enables the most streamlined deployment, albeit with limited extensibility. The findings provide empirically grounded guidance for framework selection in clinical AI system development.
📝 Abstract
Deep learning has significantly advanced the field of medical image classification, particularly with the adoption of Convolutional Neural Networks (CNNs). Various deep learning frameworks such as Keras, PyTorch and JAX offer unique advantages in model development and deployment. However, their comparative performance in medical imaging tasks remains underexplored. This study presents a comprehensive analysis of CNN implementations across these frameworks, using the PathMNIST dataset as a benchmark. We evaluate training efficiency, classification accuracy and inference speed to assess their suitability for real-world applications. Our findings highlight the trade-offs between computational speed and model accuracy, offering valuable insights for researchers and practitioners in medical image analysis.
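As an illustration of how an inference-speed measurement of this kind might be set up, the following is a minimal, framework-agnostic sketch. It is not the authors' benchmarking code: the helper `measure_latency`, its `warmup` and `repeats` parameters, and the stand-in `dummy_predict` are all hypothetical names introduced here. It uses only the Python standard library.

```python
import statistics
import time
from typing import Callable, Dict


def measure_latency(
    fn: Callable[[], object], warmup: int = 3, repeats: int = 20
) -> Dict[str, float]:
    """Time repeated calls to fn and return summary statistics in milliseconds.

    Hypothetical helper for illustration; not taken from the study.
    """
    # Warm-up calls absorb one-time costs such as JIT compilation
    # or lazy initialization, so they do not skew the timed samples.
    for _ in range(warmup):
        fn()

    samples_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1000.0)

    return {
        "median_ms": statistics.median(samples_ms),
        "mean_ms": statistics.fmean(samples_ms),
        "min_ms": min(samples_ms),
        "max_ms": max(samples_ms),
    }


def dummy_predict() -> int:
    # Hypothetical stand-in for a model's forward pass on one batch.
    return sum(i * i for i in range(10_000))


stats = measure_latency(dummy_predict)
print(stats)
```

When timing real GPU workloads, the callable passed in would also need to wait for the device to finish, since frameworks may dispatch work asynchronously: for example, by calling `torch.cuda.synchronize()` in PyTorch or `block_until_ready()` on the output array in JAX. Without that step, the timer may capture only the dispatch cost, not the computation.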