🤖 AI Summary
AlphaFold3-like models are hard to scale in training because of computationally intensive 2D attention, high overhead from retrieval-augmented data pipelines, and GPU memory bottlenecks. This work proposes a system-level optimization framework that (1) eliminates GPU idle time via ahead-of-time caching; (2) provides Triton-based, memory-efficient EvoAttention kernels for heterogeneous devices; and (3) deeply fuses common and critical small operators to reduce kernel launch overhead. Evaluated on NVIDIA H200 and AMD MI250 GPUs, the approach reduces peak GPU memory consumption by up to 1.23×, speeds up per-iteration training by up to 1.73× (H200) and 1.62× (MI250), and supports 1.35× longer sequences than PyTorch baselines without out-of-memory errors. To our knowledge, this is the first solution enabling efficient, high-throughput, cross-architecture training of AF3-scale models with robust long-sequence support.
📝 Abstract
Protein structure prediction models such as AlphaFold3 (AF3) push the frontier of biomolecular modeling by incorporating science-informed changes to the transformer architecture. However, these advances come at a steep system cost: they introduce compute- and memory-intensive operators, 2D attention mechanisms, and retrieval-augmented data pipelines, which collectively hinder the scalability of AF3 training. In this work, we present MegaFold, a cross-platform system to accelerate AF3 training. MegaFold tackles key bottlenecks through ahead-of-time caching to eliminate GPU idle time from the retrieval-augmented data pipeline, Triton-based kernels for memory-efficient EvoAttention on heterogeneous devices, and deep fusion for common and critical small operators in AF3. Evaluation on both NVIDIA H200 and AMD MI250 GPUs shows that MegaFold reduces peak memory usage of AF3 training by up to 1.23$\times$ and improves per-iteration training time by up to 1.73$\times$ and 1.62$\times$, respectively. More importantly, MegaFold enables training on 1.35$\times$ longer sequence lengths compared to PyTorch baselines without running out of memory, significantly improving the scalability of modern protein folding models. We open-source our code at https://github.com/Supercomputing-System-AI-Lab/MegaFold/.
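To make the ahead-of-time caching idea concrete, here is a minimal Python sketch of the general pattern: the expensive retrieval and featurization work runs once before training and is written to disk, so the training loop only performs cheap reads. This is an illustration under stated assumptions, not MegaFold's implementation. The names `expensive_featurize`, `precompute_cache`, and `load_features` are hypothetical, and the featurizer is a trivial stand-in for the real retrieval-augmented pipeline.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path


def expensive_featurize(sequence: str) -> dict:
    """Hypothetical stand-in for slow CPU-side work (e.g. retrieval and featurization)."""
    return {"length": len(sequence), "tokens": [ord(c) for c in sequence]}


def cache_path(cache_dir: Path, sequence: str) -> Path:
    """Derive a stable file name from the sequence content."""
    key = hashlib.sha256(sequence.encode("utf-8")).hexdigest()
    return cache_dir / f"{key}.pkl"


def precompute_cache(sequences, cache_dir: Path) -> int:
    """Run once before training; returns the number of newly written entries."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    for seq in sequences:
        path = cache_path(cache_dir, seq)
        if not path.exists():  # skip entries cached by an earlier run
            with path.open("wb") as f:
                pickle.dump(expensive_featurize(seq), f)
            written += 1
    return written


def load_features(sequence: str, cache_dir: Path) -> dict:
    """Called from the training loop: a disk read, with no featurization."""
    with cache_path(cache_dir, sequence).open("rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        cache_dir = Path(tmp)
        seqs = ["MKTAYIAK", "GSHMLE"]  # made-up example sequences
        precompute_cache(seqs, cache_dir)
        for seq in seqs:
            features = load_features(seq, cache_dir)
            print(seq, features["length"])
```

In a real training setup, `load_features` would typically sit inside the dataset's item-loading method, so that data-loading workers spend their time on I/O rather than on retrieval while the GPU waits.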