🤖 AI Summary
MTraining targets the worker- and step-level imbalance in computation and communication that arises when large language models (LLMs) are trained with dynamic sparse attention on ultra-long contexts (e.g., 512K tokens) in distributed settings. Methodologically, MTraining introduces (1) a dynamic sparse training pattern that adapts the sparsity pattern to both the sequence length and the hardware topology, and (2) two complementary attention mechanisms, balanced sparse ring attention and hierarchical sparse ring attention, which jointly reduce per-step computational imbalance and cross-worker communication overhead. On a 32-GPU A100 cluster, MTraining extends the context window of Qwen2.5-3B from 32K to 512K tokens and achieves up to 6x higher training throughput while preserving model accuracy.
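The per-step load-balancing idea can be illustrated with a small sketch. The snippet below is a hypothetical, simplified illustration rather than MTraining's actual implementation: given the per-query-block cost implied by a dynamic sparsity pattern, query blocks are greedily assigned to ring-attention workers so that each worker ends up with roughly equal work. The function name `balance_query_blocks` and the greedy bin-packing heuristic are assumptions made for illustration only.

```python
# Hypothetical sketch of balancing dynamic sparse-attention work across
# ring-attention workers. Not MTraining's API; illustration only.
import heapq
from typing import List


def balance_query_blocks(block_costs: List[int], num_workers: int) -> List[List[int]]:
    """Greedily assign query blocks to workers so per-worker cost is balanced.

    block_costs[i] is the number of key blocks that query block i attends to
    under the dynamic sparsity pattern, i.e., a proxy for its FLOP cost.
    Returns, for each worker, the list of query-block indices it owns.
    """
    # Min-heap of (accumulated_cost, worker_id); always give the next
    # most expensive block to the currently least-loaded worker.
    heap = [(0, w) for w in range(num_workers)]
    heapq.heapify(heap)
    assignment: List[List[int]] = [[] for _ in range(num_workers)]

    for block_id in sorted(range(len(block_costs)),
                           key=lambda i: block_costs[i], reverse=True):
        cost, worker = heapq.heappop(heap)
        assignment[worker].append(block_id)
        heapq.heappush(heap, (cost + block_costs[block_id], worker))
    return assignment


if __name__ == "__main__":
    # Toy sparsity pattern: uneven per-block costs, as dynamic sparse
    # attention typically produces on a long sequence.
    costs = [40, 5, 3, 37, 8, 30, 2, 35]
    for w, blocks in enumerate(balance_query_blocks(costs, num_workers=4)):
        print(f"worker {w}: blocks {blocks}, cost {sum(costs[b] for b in blocks)}")
```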
📝 Abstract
The adoption of long context windows has become a standard feature in Large Language Models (LLMs), as extended contexts significantly enhance their capacity for complex reasoning and broaden their applicability across diverse scenarios. Dynamic sparse attention is a promising approach for reducing the computational cost of long-context processing. However, efficiently training LLMs with dynamic sparse attention on ultra-long contexts, especially in distributed settings, remains a significant challenge, due in large part to worker- and step-level imbalance. This paper introduces MTraining, a novel distributed methodology that leverages dynamic sparse attention to enable efficient training of LLMs with ultra-long contexts. Specifically, MTraining integrates three key components: a dynamic sparse training pattern, balanced sparse ring attention, and hierarchical sparse ring attention. These components synergistically address the computational imbalance and communication overheads inherent in dynamic sparse attention during training with extensive context lengths. We demonstrate the efficacy of MTraining by training Qwen2.5-3B, successfully expanding its context window from 32K to 512K tokens on a cluster of 32 A100 GPUs. Our evaluations on a comprehensive suite of downstream tasks, including RULER, PG-19, InfiniteBench, and Needle In A Haystack, show that MTraining achieves up to 6x higher training throughput while preserving model accuracy. Our code is available at https://github.com/microsoft/MInference/tree/main/MTraining.
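For intuition on the hierarchical communication side, below is a minimal sketch assuming a two-level ring: workers first exchange KV blocks within their node, and designated node leaders then exchange across nodes. This is only one way such a topology could be organized; the function `build_hierarchical_rings` and the leader-based scheme are illustrative assumptions, not the paper's actual mechanism.

```python
# Hypothetical two-level ring schedule: a fast intra-node ring per node plus
# an inter-node ring over node leaders. Illustration only, not MTraining code.
from typing import Dict, List, Tuple


def build_hierarchical_rings(num_nodes: int,
                             gpus_per_node: int) -> Dict[str, List[Tuple[int, int]]]:
    """Return (src, dst) send edges for the intra-node and inter-node rings.

    Worker ids are assigned node-major: worker = node * gpus_per_node + local_rank.
    """
    intra: List[Tuple[int, int]] = []
    inter: List[Tuple[int, int]] = []
    for node in range(num_nodes):
        base = node * gpus_per_node
        for local in range(gpus_per_node):
            # Ring inside each node: local_rank -> (local_rank + 1) % gpus_per_node.
            intra.append((base + local, base + (local + 1) % gpus_per_node))
    leaders = [node * gpus_per_node for node in range(num_nodes)]
    for i, src in enumerate(leaders):
        # Only node leaders participate in the cross-node ring, so other
        # workers never send traffic over the slower inter-node links.
        inter.append((src, leaders[(i + 1) % num_nodes]))
    return {"intra_node": intra, "inter_node": inter}


if __name__ == "__main__":
    rings = build_hierarchical_rings(num_nodes=4, gpus_per_node=8)  # e.g., 32 GPUs
    print(len(rings["intra_node"]), "intra-node edges;",
          len(rings["inter_node"]), "inter-node edges")
```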