🤖 AI Summary
To address the high memory consumption and slow computation of multi-dimensional Transformers on long sequences, as well as the substantial communication overhead of existing sequence-parallel methods, this paper proposes Dynamic Sequence Parallelism (DSP), a new abstraction of sequence parallelism. Conventional (embedded) sequence parallelism shards along a single, fixed sequence dimension. DSP instead switches the sharded dimension according to the computation stage and uses an efficient resharding strategy to move between layouts. This lowers communication cost, adapts to different modules, and is easy to implement with minimal constraints. Compared with state-of-the-art embedded sequence parallelism, DSP improves throughput by 32.2% to 10× while using less than 25% of the baselines' communication volume.
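To make the idea of switching the sharded dimension per module concrete, here is a minimal, hypothetical planning sketch. It assumes each stage can be described by the set of sequence dimensions it needs unsharded (for example, spatial attention needs the spatial axis `S` whole, temporal attention needs the temporal axis `T` whole, and a position-wise MLP needs neither). It then picks one dimension to shard for every stage and counts how often that choice changes, as a proxy for the number of resharding points. The names `plan_shard_dims`, `count_switches` and the greedy rule are illustrative assumptions, not the paper's algorithm.

```python
from typing import List, Optional, Sequence, Set


def _valid_run(required_full: Sequence[Set[str]], start: int, dim: str) -> int:
    """Number of consecutive stages from `start` that can keep `dim` sharded."""
    run = 0
    for need in required_full[start:]:
        if dim in need:
            break
        run += 1
    return run


def plan_shard_dims(required_full: Sequence[Set[str]], dims: Sequence[str]) -> List[str]:
    """Pick one dimension to shard per stage (hypothetical greedy planner).

    Keep the current sharded dimension while the stage allows it; otherwise switch
    to the allowed dimension that stays valid for the most upcoming stages.
    """
    plan: List[str] = []
    current: Optional[str] = None
    for i, need in enumerate(required_full):
        if current is None or current in need:
            candidates = [d for d in dims if d not in need]
            if not candidates:
                raise ValueError(f"stage {i} needs every dimension unsharded")
            current = max(candidates, key=lambda d: _valid_run(required_full, i, d))
        plan.append(current)
    return plan


def count_switches(plan: Sequence[str]) -> int:
    """Assumption: each change of sharded dimension costs one resharding step."""
    return sum(1 for a, b in zip(plan, plan[1:]) if a != b)


if __name__ == "__main__":
    # spatial attention, MLP, temporal attention, MLP, repeated twice
    stages = [{"S"}, set(), {"T"}, set()] * 2
    plan = plan_shard_dims(stages, ["T", "S"])
    print(plan)                  # ['T', 'T', 'S', 'S', 'T', 'T', 'S', 'S']
    print(count_switches(plan))  # 3
```

In this toy block, the position-wise MLP stages never force a switch, so resharding happens only where an attention stage needs the currently sharded axis whole.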
📝 Abstract
Scaling multi-dimensional transformers to long sequences is indispensable across various domains. However, the large memory requirements and slow processing of such sequences necessitate sequence parallelism. All existing approaches fall under the category of embedded sequence parallelism, which is limited to sharding along a single sequence dimension and thereby introduces significant communication overhead. Yet multi-dimensional transformers perform independent computations along each of their sequence dimensions. To exploit this, we propose Dynamic Sequence Parallelism (DSP) as a novel abstraction of sequence parallelism. DSP dynamically switches the parallel dimension among all sequence dimensions according to the computation stage, using an efficient resharding strategy. DSP offers significant reductions in communication costs, adaptability across modules, and ease of implementation with minimal constraints. Experimental evaluations demonstrate DSP's superiority over state-of-the-art embedded sequence parallelism methods, with throughput improvements ranging from 32.2% to 10x while using less than 25% of their communication volume.
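The switch of the parallel dimension can be illustrated with a single-process simulation. The sketch below assumes a 2-D sequence indexed `[t][s]` (hidden dimension omitted), uses prefix sums as hypothetical stand-ins for temporal and spatial attention, and models the devices as a Python list of shards. Sharding on `s` lets the temporal operation run locally; one all-to-all then switches the sharded dimension to `t` so the spatial operation also runs locally. This is not the authors' implementation; a real system would issue an all-to-all collective across GPUs.

```python
from typing import List

# A 2-D sequence stored as rows indexed [t][s]; the hidden dimension is omitted.
Matrix = List[List[float]]


def temporal_op(x: Matrix) -> Matrix:
    """Hypothetical stand-in for temporal attention: prefix sum along t, per s."""
    out = [row[:] for row in x]
    for t in range(1, len(out)):
        for s in range(len(out[t])):
            out[t][s] += out[t - 1][s]
    return out


def spatial_op(x: Matrix) -> Matrix:
    """Hypothetical stand-in for spatial attention: prefix sum along s, per t."""
    out = []
    for row in x:
        acc, new_row = 0.0, []
        for v in row:
            acc += v
            new_row.append(acc)
        out.append(new_row)
    return out


def shard_on_s(x: Matrix, n: int) -> List[Matrix]:
    """Device d keeps every t but only columns [d*cs, (d+1)*cs)."""
    cs = len(x[0]) // n
    return [[row[d * cs:(d + 1) * cs] for row in x] for d in range(n)]


def all_to_all_s_to_t(shards: List[Matrix]) -> List[Matrix]:
    """Switch the sharded dimension from s to t.

    Device src cuts its block into n row chunks and sends chunk dst to device dst;
    device dst concatenates the received column blocks in source order.
    """
    n = len(shards)
    rs = len(shards[0]) // n
    resharded = []
    for dst in range(n):
        rows = []
        for t in range(dst * rs, (dst + 1) * rs):
            full_row: List[float] = []
            for src in range(n):
                full_row.extend(shards[src][t])
            rows.append(full_row)
        resharded.append(rows)
    return resharded


def gather_on_t(shards: List[Matrix]) -> Matrix:
    """Concatenate the t chunks back into one matrix (only needed to check results)."""
    return [row for shard in shards for row in shard]


def dsp_forward(x: Matrix, n: int) -> Matrix:
    """Temporal stage sharded on s, one reshard, spatial stage sharded on t.

    Assumes both the t and s lengths are divisible by n.
    """
    shards = shard_on_s(x, n)
    shards = [temporal_op(blk) for blk in shards]  # local: each shard has the full t axis
    shards = all_to_all_s_to_t(shards)              # the single resharding step
    shards = [spatial_op(blk) for blk in shards]    # local: each shard has the full s axis
    return gather_on_t(shards)


if __name__ == "__main__":
    x = [[float(10 * t + s) for s in range(8)] for t in range(4)]
    reference = spatial_op(temporal_op(x))  # unsharded computation
    print(dsp_forward(x, 2) == reference)   # True
```

In the all-to-all, each device splits its local block into `n` chunks, keeps one and sends the other `n - 1`, so between operations every device holds only `1/n` of the sequence.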