🤖 AI Summary
To address the high memory overhead of large language model (LLM) training and the limitations of existing layer-wise optimization methods—which ignore intra-layer module heterogeneity and achieve only limited memory savings—this paper proposes Module-wise Importance SAmpling (MISA). MISA decomposes each Transformer layer into fine-grained modules that can be activated or frozen independently and introduces, for the first time, module-level importance scoring coupled with weighted stochastic sampling to provably reduce gradient variance. The authors establish an $\mathcal{O}(1/\sqrt{K})$ convergence rate for MISA under non-convex stochastic optimization. Empirical results demonstrate that MISA significantly reduces GPU memory consumption—by up to 42%—across diverse tasks, while preserving or even improving model performance over state-of-the-art baselines. The implementation is publicly available.
📝 Abstract
The substantial memory demands of pre-training and fine-tuning large language models (LLMs) require memory-efficient optimization algorithms. One promising approach is layer-wise optimization, which treats each transformer block as a single layer and optimizes it sequentially, while freezing the other layers to save optimizer states and activations. Although effective, these methods ignore the varying importance of the modules within each layer, leading to suboptimal performance. Moreover, layer-wise sampling provides only limited memory savings, as at least one full layer must remain active during optimization. To overcome these limitations, we propose Module-wise Importance SAmpling (MISA), a novel method that divides each layer into smaller modules and assigns importance scores to each module. MISA uses a weighted random sampling mechanism to activate modules, provably reducing gradient variance compared to layer-wise sampling. Additionally, we establish an $\mathcal{O}(1/\sqrt{K})$ convergence rate under non-convex and stochastic conditions, where $K$ is the total number of block updates, and provide a detailed memory analysis showcasing MISA's superiority over existing baseline methods. Experiments on diverse learning tasks validate the effectiveness of MISA. Source code is available at https://github.com/pkumelon/MISA.
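To make the selection mechanism concrete, the following is a minimal, hypothetical sketch of the module-selection step the abstract describes: a transformer layer is split into named modules, each carries an importance score, and a weighted random draw (without replacement) picks which modules stay active while the rest are frozen. The module names, scores, and the `sample_active_modules` helper are illustrative assumptions, not the authors' implementation, which lives in the linked repository.

```python
import random

def sample_active_modules(importance, k, rng=None):
    """Pick k modules to activate, with probability proportional to their
    importance scores (weighted sampling without replacement).

    NOTE: illustrative sketch of MISA-style module selection, not the
    authors' code; the real method also handles score updates and
    variance-reduced gradient weighting.
    """
    rng = rng or random.Random(0)
    scores = dict(importance)  # remaining candidate modules
    active = []
    for _ in range(k):
        total = sum(scores.values())
        r = rng.random() * total  # point in the cumulative score mass
        acc = 0.0
        for name, s in scores.items():
            acc += s
            if r <= acc:
                active.append(name)
                del scores[name]  # sample without replacement
                break
    return active

# Hypothetical decomposition of one layer into fine-grained modules
importance = {"attn.q": 0.4, "attn.k": 0.1, "attn.v": 0.3,
              "mlp.up": 0.15, "mlp.down": 0.05}
active = sample_active_modules(importance, k=2)
frozen = [m for m in importance if m not in active]
```

In a training loop, only the parameters of `active` modules would receive optimizer states and gradients for that step, which is where the memory savings over whole-layer activation come from.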