🤖 AI Summary
This work addresses the high memory footprint of intermediate activations in large language model training, which limits scalability and efficiency. The authors propose OASIS, an online activation subspace learning method that maintains a low-dimensional activation subspace during training and compresses activations by projecting them onto it, while leaving the forward computation unchanged. Gradients and optimizer states are also kept within this subspace. The key innovations are continuous online updating of the activation subspace throughout training, in contrast to prior schemes that recompute projections periodically, and a projection-aware optimizer that transports optimizer states across subspace updates to keep training stable. Across various pretraining and fine-tuning tasks, the method lowers peak memory by up to 2× relative to full fine-tuning, matches full fine-tuning performance, and outperforms existing low-rank approaches.
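The core memory saving can be illustrated for a single linear layer `y = x @ W.T`. The sketch below is a minimal NumPy illustration under stated assumptions, not the OASIS implementation: it assumes a fixed orthonormal basis `P` (a hypothetical name) and activations that lie exactly in its span. The forward pass uses the full activations. Only the `n × r` projection coefficients are stored for the backward pass, and the weight gradient then factors as an `m × r` matrix times `P.T`. This is why gradients and optimizer states can be kept in the subspace.

```python
import numpy as np

def compress_activation(x, P):
    """Hypothetical helper: project activations x (n x d) onto the subspace
    spanned by the orthonormal columns of P (d x r). Only the n x r
    coefficients are stored for the backward pass."""
    return x @ P

def linear_weight_grad_lowrank(grad_out, x_coeff):
    """Hypothetical helper: low-rank factor of the weight gradient of y = x @ W.T.
    The exact gradient is grad_out.T @ x (m x d). With x ~= x_coeff @ P.T it
    becomes (grad_out.T @ x_coeff) @ P.T, so only the m x r factor is kept."""
    return grad_out.T @ x_coeff

rng = np.random.default_rng(0)
n, d, m, r = 64, 32, 16, 4

# Idealized assumption: activations lie exactly in an r-dim subspace.
basis, _ = np.linalg.qr(rng.standard_normal((d, r)))   # d x r, orthonormal columns
x = rng.standard_normal((n, r)) @ basis.T
W = rng.standard_normal((m, d))
y = x @ W.T                                  # forward pass uses the full, uncompressed x
grad_out = rng.standard_normal((n, m))       # upstream gradient dL/dy

x_coeff = compress_activation(x, basis)                # n x r stored instead of n x d
g_low = linear_weight_grad_lowrank(grad_out, x_coeff)  # m x r low-rank gradient factor
g_full = grad_out.T @ x                                # m x d reference gradient

print(x.size, x_coeff.size)                   # 2048 256
print(np.allclose(g_low @ basis.T, g_full))   # True
```

When activations only approximately lie in the subspace, `g_low @ P.T` equals the exact gradient projected onto the subspace (`g_full @ P @ P.T`) rather than matching it exactly.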
📝 Abstract
Training large language models (LLMs) is constrained by memory requirements, with activations accounting for a substantial fraction of the total footprint. Existing approaches reduce memory using low-rank weight parameterizations or low-rank gradient subspaces for optimizer states, while activation memory is addressed through architectural modifications or compression schemes based on periodically updated projections. We propose OASIS, an online activation subspace learning algorithm for memory-efficient training that tracks and continuously updates a low-dimensional activation subspace during training. Intermediate activations are projected onto this evolving subspace, reducing memory without modifying forward-pass computations. The evolving activation subspace induces low-rank gradient representations, enabling both gradients and optimizer states to be maintained directly in this subspace, while a projection-aware optimizer consistently transports optimizer states across subspace updates for stable training. Across various fine-tuning and pretraining tasks, OASIS achieves up to 2× lower peak memory than full fine-tuning while matching its performance and outperforming prior low-rank methods.
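The abstract does not specify how the subspace is tracked or how optimizer states are transported. The following is therefore a hedged sketch built on swapped-in techniques, not the OASIS algorithm. It uses a damped orthogonal-iteration step on each batch's activation covariance as the online subspace update, and a change of basis for a low-rank first-moment buffer as the transport step. The names `subspace_step`, `transport_moment`, and `beta` are hypothetical.

```python
import numpy as np

def subspace_step(P, x, beta=0.5):
    """Hypothetical online update of an orthonormal basis P (d x r) from one
    batch of activations x (n x d): a damped orthogonal-iteration step
    P <- orth((1 - beta) * P + beta * C @ P), where C is the batch covariance."""
    CP = x.T @ (x @ P) / x.shape[0]          # C @ P without forming the d x d matrix C
    Q, _ = np.linalg.qr((1.0 - beta) * P + beta * CP)
    return Q

def transport_moment(M, P_old, P_new):
    """Hypothetical transport of a low-rank moment M (m x r) stored in the
    coordinates of P_old: M @ P_old.T is its full-space form, and multiplying
    by P_new re-expresses (projects) it in the coordinates of P_new."""
    return M @ (P_old.T @ P_new)

rng = np.random.default_rng(0)
d, r, m, n = 20, 3, 8, 64
U, _ = np.linalg.qr(rng.standard_normal((d, r)))   # assumed "true" activation subspace
P, _ = np.linalg.qr(rng.standard_normal((d, r)))   # random initial basis
M = rng.standard_normal((m, r))                    # e.g. a momentum buffer in P coordinates

for step in range(50):
    # Fresh batch each step: strong signal in span(U) plus small isotropic noise.
    x = 3.0 * rng.standard_normal((n, r)) @ U.T + 0.01 * rng.standard_normal((n, d))
    P_new = subspace_step(P, x)
    M = transport_moment(M, P, P_new)              # keep the state consistent with the new basis
    P = P_new

# Fraction of the learned subspace lying inside span(U); 1.0 means identical subspaces.
alignment = np.linalg.norm(U.T @ P) ** 2 / r
print(round(alignment, 3))
```

The change of basis is exact for a linear state such as a momentum buffer. Elementwise second-moment statistics (as in Adam) do not transform this way, and the abstract does not describe how OASIS handles them.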