MISA: Memory-Efficient LLMs Optimization with Module-wise Importance Sampling

📅 2025-10-28
📈 Citations: 0
Influential: 0
🤖 AI Summary
To address the high memory overhead of large language model (LLM) training and the limitations of existing layer-wise optimization methods, which ignore intra-layer module heterogeneity and achieve only marginal memory savings, this paper proposes Module-wise Importance SAmpling (MISA). MISA decomposes each Transformer layer into fine-grained modules that can be activated or frozen independently and introduces, for the first time, module-level importance scoring coupled with weighted random sampling to provably reduce gradient variance. The paper establishes an $\mathcal{O}(1/\sqrt{K})$ convergence rate for MISA under non-convex, stochastic optimization. Empirical results demonstrate that MISA significantly reduces GPU memory consumption, by up to 42%, across diverse tasks, while preserving or even improving model performance over state-of-the-art baselines. The implementation is publicly available.

📝 Abstract
The substantial memory demands of pre-training and fine-tuning large language models (LLMs) require memory-efficient optimization algorithms. One promising approach is layer-wise optimization, which treats each transformer block as a single layer and optimizes it sequentially, while freezing the other layers to save optimizer states and activations. Although effective, these methods ignore the varying importance of the modules within each layer, leading to suboptimal performance. Moreover, layer-wise sampling provides only limited memory savings, as at least one full layer must remain active during optimization. To overcome these limitations, we propose Module-wise Importance SAmpling (MISA), a novel method that divides each layer into smaller modules and assigns importance scores to each module. MISA uses a weighted random sampling mechanism to activate modules, provably reducing gradient variance compared to layer-wise sampling. Additionally, we establish an $\mathcal{O}(1/\sqrt{K})$ convergence rate under non-convex and stochastic conditions, where $K$ is the total number of block updates, and provide a detailed memory analysis showcasing MISA's superiority over existing baseline methods. Experiments on diverse learning tasks validate the effectiveness of MISA. Source code is available at https://github.com/pkumelon/MISA.
Problem

Research questions and friction points this paper is trying to address.

Reduces memory demands in LLM pre-training and fine-tuning optimization
Overcomes limitations of layer-wise methods by module importance sampling
Improves performance while minimizing optimizer states and activations memory
Innovation

Methods, ideas, or system contributions that make the work stand out.

Divides layers into modules with importance scores
Uses weighted random sampling to activate modules
Reduces gradient variance and memory usage
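The weighted activation mechanism above can be sketched as a weighted draw without replacement over a layer's modules: sample a few modules with probability proportional to their importance scores, keep those trainable, and leave the rest frozen. This is a minimal illustrative sketch under assumed names (`sample_active_modules`, the example scores), not the paper's actual implementation, which is in the linked repository:

```python
import random

def sample_active_modules(importance, k, rng=None):
    """Sample k distinct module indices with probability proportional
    to their importance scores (weighted sampling without replacement).

    Illustrative sketch only; module scoring in MISA is learned/updated
    during training, whereas here the scores are given as fixed inputs.
    """
    rng = rng or random.Random()
    scores = list(importance)
    indices = list(range(len(scores)))
    chosen = []
    for _ in range(min(k, len(indices))):
        total = sum(scores[i] for i in indices)
        r = rng.random() * total
        acc = 0.0
        for pos, i in enumerate(indices):
            acc += scores[i]
            if r <= acc:
                chosen.append(indices.pop(pos))
                break
        else:
            # Guard against floating-point underrun at the end of the scan.
            chosen.append(indices.pop())
    return chosen

# Example: 6 hypothetical sub-modules of one Transformer layer
# (e.g., attention projections and MLP blocks); activate 2 per step.
# The unselected modules would stay frozen, so their optimizer states
# and activations need not be kept in GPU memory.
importance = [0.9, 0.1, 0.5, 0.4, 0.8, 0.2]
active = sample_active_modules(importance, 2, random.Random(0))
```

In a training loop, only the parameters of the sampled modules would have gradients enabled for that step; higher-importance modules are activated more often, which is what drives the variance reduction relative to uniform layer-wise sampling.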
Yuxi Liu (University of California, Berkeley)
Renjia Deng (Peking University)
Yutong He (Peking University)
Xue Wang (Alibaba DAMO Academy)
Tao Yao (Alibaba)
Kun Yuan (Peking University)