Adjoint sharding for very long context training of state space models

📅 2025-01-01
📈 Citations: 0
Influential: 0
🤖 AI Summary
To address GPU memory explosion and excessive training time when training large language models (LLMs) end to end on very long (up to million-token) contexts, this paper introduces *adjoint sharding*, a gradient computation technique that yields the same gradients as backpropagation. It also presents a truncated variant and distributed and parallel implementations. The method combines the adjoint method, sharding of the gradient computation, and state space model (SSM) optimizations. On a 1.27B-parameter model, it reduces training memory usage by up to 3× and raises the maximum training context length from 35K to more than 100K tokens on five AWS P4 instances. The work aims to make training and fine-tuning directly on very long contexts tractable, instead of training on short contexts and relying on inference-time techniques for long ones.

📝 Abstract
Despite very fast progress, efficiently training large language models (LLMs) on very long contexts remains challenging. Existing methods fall back to training LLMs on short contexts (at most a few thousand tokens during training) and use inference-time techniques when evaluating on long contexts (context windows above 1M tokens at inference). Unlike long-context inference, training on very long input prompts is quickly limited by available GPU memory and by the prohibitively long training times it requires on state-of-the-art hardware. Meanwhile, many real-life applications require not only inference but also training or fine-tuning with long contexts on specific tasks. Examples include augmenting the context with various sources of raw reference information for fact extraction, fact summarization, or fact reconciliation tasks. We propose adjoint sharding, a novel technique that shards the gradient calculation during training to reduce memory requirements by orders of magnitude, making training on very long contexts computationally tractable. Adjoint sharding is based on the adjoint method and computes gradients equivalent to those of backpropagation. We also propose truncated adjoint sharding to speed up the algorithm while maintaining performance. We provide a distributed version and a parallel version of adjoint sharding to further speed up training. Empirical results show that the proposed adjoint sharding algorithm reduces memory usage by up to 3X with a 1.27B-parameter large language model on 1M-context-length training. This makes it possible to increase the maximum context length during training or fine-tuning of a 1.27B-parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
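
The claim that adjoint sharding computes the same gradients as backpropagation can be illustrated on a toy model. This is a minimal, hypothetical sketch, not the paper's implementation. It assumes a scalar linear state space model h_t = a·h_{t-1} + b·x_t, y_t = c·h_t with a per-token squared-error loss. All function names (`forward`, `backprop_grads`, `shard_grads`, `sharded_grads`) are invented for illustration. Each shard computes the gradient of a single token's loss term. It does this by carrying an adjoint backward from that token, using only the stored hidden states, so no shard depends on another. Summing the shards reproduces the backpropagation-through-time gradient.

```python
# Hypothetical toy model (illustration only, not the paper's code):
#   h_t = a * h_{t-1} + b * x_t,  h_0 = 0
#   y_t = c * h_t
#   L   = sum_t l_t,  l_t = 0.5 * (y_t - target_t) ** 2


def forward(a, b, c, xs):
    """Run the recurrence; return hidden states (hs[0] = h_0) and outputs."""
    hs = [0.0]
    for x in xs:
        hs.append(a * hs[-1] + b * x)
    ys = [c * h for h in hs[1:]]
    return hs, ys


def backprop_grads(a, b, c, xs, targets):
    """Reference: ordinary backpropagation through time over the whole sequence."""
    hs, ys = forward(a, b, c, xs)
    ga = gb = gc = 0.0
    delta = 0.0  # dL/dh_{t+1}; zero past the last step
    for t in range(len(xs), 0, -1):
        g = ys[t - 1] - targets[t - 1]  # dl_t/dy_t
        gc += g * hs[t]
        delta = g * c + a * delta       # dL/dh_t
        ga += delta * hs[t - 1]
        gb += delta * xs[t - 1]
    return ga, gb, gc


def shard_grads(t, a, c, hs, xs, ys, targets):
    """Gradient of the single loss term l_t (t is 1-indexed).

    Uses only the stored states and inputs, so shards are independent.
    """
    g = ys[t - 1] - targets[t - 1]
    lam = g * c  # adjoint dl_t/dh_t
    ga = gb = 0.0
    for k in range(t, 0, -1):
        ga += lam * hs[k - 1]  # local vector-Jacobian product of step k w.r.t. a
        gb += lam * xs[k - 1]  # ... and w.r.t. b
        lam *= a               # carry the adjoint back one step: dl_t/dh_{k-1}
    return ga, gb, g * hs[t]


def sharded_grads(a, b, c, xs, targets):
    """Sum the per-token shards; each one could run on a different device."""
    hs, ys = forward(a, b, c, xs)
    total = [0.0, 0.0, 0.0]
    for t in range(1, len(xs) + 1):
        for i, v in enumerate(shard_grads(t, a, c, hs, xs, ys, targets)):
            total[i] += v
    return tuple(total)


xs = [1.0, -0.5, 2.0, 0.3, -1.2]
targets = [0.5, 0.0, 1.0, -0.2, 0.4]
ref = backprop_grads(0.9, 0.7, 1.3, xs, targets)
sharded = sharded_grads(0.9, 0.7, 1.3, xs, targets)
print(all(abs(r - s) < 1e-9 for r, s in zip(ref, sharded)))  # True
```

In this toy, each shard revisits every earlier step, so total work grows quadratically with sequence length. What the decomposition provides is independence between shards. That is the property this sketch assumes sharding relies on, because it lets the gradient be computed in pieces that fit in memory or run on separate devices.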

Problem

Efficiently training large language models on very long contexts
Reducing GPU memory requirements for long-context training
Enabling training and fine-tuning on tasks requiring long contexts

Innovation

Adjoint sharding shards the gradient computation to reduce training memory usage
Truncated adjoint sharding speeds up the algorithm while maintaining performance
Distributed and parallel versions further accelerate long-context training
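
The truncated and distributed variants can be sketched on the same kind of toy model. This is a hypothetical illustration under stated assumptions, not the paper's algorithm. It assumes a scalar linear SSM. It also assumes that truncation means each shard carries its adjoint back over only the most recent `window` steps; the paper's exact truncation rule may differ. Python's `concurrent.futures.ThreadPoolExecutor` only stands in for separate devices, since threads give no real parallelism for this pure-Python code. The helper names `shard_grads`, `worker` and `distributed_grads` are invented for the example.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical toy model (illustration only, not the paper's code):
#   h_t = a * h_{t-1} + b * x_t,  h_0 = 0
#   y_t = c * h_t,  l_t = 0.5 * (y_t - target_t) ** 2


def forward(a, b, c, xs):
    """Return stored hidden states (hs[0] = h_0) and outputs."""
    hs = [0.0]
    for x in xs:
        hs.append(a * hs[-1] + b * x)
    return hs, [c * h for h in hs[1:]]


def shard_grads(t, a, c, hs, xs, ys, targets, window=None):
    """Gradient of l_t w.r.t. (a, b, c); t is 1-indexed.

    Assumed truncation rule: with `window` set, the adjoint is carried back
    over at most `window` steps instead of all the way to step 1.
    """
    g = ys[t - 1] - targets[t - 1]  # dl_t/dy_t
    lam = g * c                     # adjoint dl_t/dh_t
    ga = gb = 0.0
    k_min = 1 if window is None else max(1, t - window + 1)
    for k in range(t, k_min - 1, -1):
        ga += lam * hs[k - 1]  # local VJP of step k w.r.t. a
        gb += lam * xs[k - 1]  # local VJP of step k w.r.t. b
        lam *= a               # dl_t/dh_{k-1}
    return ga, gb, g * hs[t]


def worker(ts, a, c, hs, xs, ys, targets, window):
    """Sum the shards assigned to one (simulated) worker."""
    acc = [0.0, 0.0, 0.0]
    for t in ts:
        for i, v in enumerate(shard_grads(t, a, c, hs, xs, ys, targets, window)):
            acc[i] += v
    return acc


def distributed_grads(a, b, c, xs, targets, n_workers=4, window=None):
    """Split token shards across workers, then sum the partial gradients."""
    hs, ys = forward(a, b, c, xs)
    ts = list(range(1, len(xs) + 1))
    blocks = [ts[i::n_workers] for i in range(n_workers)]  # round-robin split
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # list() materializes the results so they can be iterated more than once
        parts = list(pool.map(
            lambda blk: worker(blk, a, c, hs, xs, ys, targets, window), blocks))
    # Analogue of an all-reduce: add up every worker's partial gradient.
    return tuple(sum(p[i] for p in parts) for i in range(3))


xs = [((7 * i) % 11 - 5) / 5 for i in range(200)]
targets = [0.1 * ((3 * i) % 7 - 3) for i in range(200)]
full = distributed_grads(0.8, 0.5, 1.1, xs, targets)
trunc = distributed_grads(0.8, 0.5, 1.1, xs, targets, window=64)
print(max(abs(f - tr) for f, tr in zip(full, trunc)))
```

Token indices are dealt to workers round-robin because, without truncation, later tokens have longer adjoint sums. With a window, every shard costs at most `window` steps, so total work drops from quadratic in sequence length to roughly length × window. Because |a| < 1 here, the dropped terms decay like a^window, which is why a modest window stays close to the full gradient in this toy.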