🤖 AI Summary
Frequent hardware failures during distributed training of large language models (LLMs) severely degrade training stability and resource utilization, while existing fault-tolerance mechanisms incur substantial computational and memory overhead. Method: This paper proposes MeCeFO, an efficient fault-tolerant optimization framework integrating three techniques: (1) skip connections that bypass multi-head attention (MHA) backpropagation for accelerated recovery; (2) on-the-fly recomputation of feed-forward network (FFN) activations to reduce GPU memory footprint; and (3) low-rank gradient approximation to accelerate weight updates. At the algorithmic level, it seamlessly migrates a failed node's task to a neighboring node while preserving the O(1/√(nT)) convergence rate of standard distributed SGD. Contribution/Results: Experiments show only a 4.18% throughput degradation under high failure rates, with fault-tolerance efficiency 5.0–6.7× higher than state-of-the-art methods, significantly improving training robustness and hardware utilization.
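The recomputation idea in technique (2) is standard activation checkpointing: rather than caching an FFN's hidden activation for the backward pass, the layer stores only its input and recomputes the activation when gradients are needed, trading one extra forward matmul for memory. A minimal numpy sketch, assuming a two-layer ReLU FFN (function names are illustrative, not from the MeCeFO codebase):

```python
import numpy as np

# Hedged sketch of FFN activation recomputation (checkpointing).
# The hidden activation H = relu(X @ W1) is NOT cached in forward;
# backward recomputes it from the stored input X, so only X (not H)
# occupies activation memory between the two passes.

def ffn_forward(X, W1, W2):
    H = np.maximum(X @ W1, 0.0)   # hidden activation, discarded after use
    return H @ W2

def ffn_backward_recompute(X, W1, W2, dOut):
    H = np.maximum(X @ W1, 0.0)   # recomputed on the fly, not loaded
    dW2 = H.T @ dOut              # grad w.r.t. second weight matrix
    dH = (dOut @ W2.T) * (H > 0)  # ReLU mask comes from recomputed H
    dW1 = X.T @ dH                # grad w.r.t. first weight matrix
    return dW1, dW2
```

The recomputed gradients are bit-identical to those from a cached-activation backward, since the same deterministic forward is replayed.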
📝 Abstract
As distributed optimization scales to meet the demands of Large Language Model (LLM) training, hardware failures become increasingly non-negligible. Existing fault-tolerant training methods often introduce significant computational or memory overhead, demanding additional resources. To address this challenge, we propose Memory- and Computation-efficient Fault-tolerant Optimization (MeCeFO), a novel algorithm that ensures robust training with minimal overhead. When a computing node fails, MeCeFO seamlessly transfers its training task to a neighboring node while employing memory- and computation-efficient algorithmic optimizations to minimize the extra workload imposed on the neighboring node handling both tasks. MeCeFO leverages three key algorithmic designs: (i) Skip-connection, which drops the multi-head attention (MHA) module during backpropagation for memory- and computation-efficient approximation; (ii) Recomputation, which reduces activation memory in feedforward networks (FFNs); and (iii) Low-rank gradient approximation, enabling efficient estimation of FFN weight matrix gradients. Theoretically, MeCeFO matches the convergence rate of conventional distributed training, with a rate of $\mathcal{O}(1/\sqrt{nT})$, where n is the data parallelism size and T is the number of iterations. Empirically, MeCeFO maintains robust performance under high failure rates, incurring only a 4.18% drop in throughput, demonstrating 5.0$\times$ to 6.7$\times$ greater resilience than previous SOTA approaches. Codes are available at https://github.com/pkumelon/MeCeFO.
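To make design (iii) concrete: for a linear layer $Y = XW$, the exact weight gradient is $\nabla W = X^\top \, dY$, which costs $O(b \cdot d_{in} \cdot d_{out})$ per batch. A rank-$r$ approximation routes this product through an $r$-dimensional bottleneck, cutting both compute and the memory held for the update. A minimal numpy sketch using a truncated SVD of the output gradient (an illustrative stand-in; the paper's exact factorization may differ):

```python
import numpy as np

# Hedged sketch of low-rank FFN gradient approximation.
# Exact gradient: dW = X.T @ dY, cost O(b * d_in * d_out).
# Here dY is compressed to rank r via truncated SVD, so the
# product is formed through an r-dim bottleneck instead:
# cost O(b * d_in * r + r * d_in * d_out) plus the SVD.

def lowrank_grad(X, dY, r):
    U, s, Vt = np.linalg.svd(dY, full_matrices=False)
    Ur, sr, Vtr = U[:, :r], s[:r], Vt[:r]        # rank-r factors of dY
    return (X.T @ Ur) @ (sr[:, None] * Vtr)      # approx of X.T @ dY
```

When the output gradient is (near) low-rank, as is often observed for transformer activations, the approximation is (near) exact while the update touches far fewer numbers.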