🤖 AI Summary
Low-rank training reduces memory consumption but degrades performance because it confines either the weights (e.g., LoRA) or the gradients (e.g., GaLore) to a low-rank subspace; full-rank training reaches higher accuracy at a prohibitive memory cost. Method: We propose Fira, a framework that enables *true* full-rank training (i.e., training with full-rank gradients of full-rank weights) while keeping the optimizer state under the same low-rank memory constraint. Its two core components are a norm-based scaling method, which substitutes the scaling effect of the low-rank adaptive optimizer on the gradient norm for that of the full-rank optimizer, and a norm-growth limiter, which regulates the relative increase of gradient norms to suppress loss spikes. Contribution/Results: In both pre-training and fine-tuning, Fira outperforms LoRA and GaLore and matches or exceeds full-rank training, with memory usage comparable to low-rank methods. The authors present it as the first attempt to combine low-rank memory efficiency with full-rank training.
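The memory argument can be checked with simple arithmetic on optimizer-state sizes. The sketch below assumes Adam (two moment tensors per weight matrix) and a GaLore-style low-rank variant that keeps its moments in an r×n subspace plus an m×r projection matrix; the 4096×4096 matrix and rank 128 are illustrative values rather than settings from the paper, and `adam_state_elements` is a hypothetical helper.

```python
def adam_state_elements(m, n, r=None):
    """Count optimizer-state elements for one m x n weight matrix.

    Full-rank Adam stores two m x n moment tensors. The low-rank variant
    (assumed GaLore-style) stores two r x n moments plus an m x r projection.
    """
    if r is None:
        return 2 * m * n
    return 2 * r * n + m * r


full = adam_state_elements(4096, 4096)
low = adam_state_elements(4096, 4096, r=128)
print(full, low, round(full / low, 1))  # prints: 33554432 1572864 21.3
```

Under these assumptions the low-rank state is roughly 21 times smaller, which is the budget a method must stay within to count as memory-efficient in this setting.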
📝 Abstract
Low-rank training has emerged as a promising approach for reducing memory usage in training Large Language Models (LLMs). Previous methods either decompose weight matrices (e.g., LoRA) or decompose gradient matrices (e.g., GaLore) to reduce memory consumption. However, both constrain training to a low-rank subspace, which inevitably leads to sub-optimal performance. This raises a question: is it possible to consistently preserve the low-rank constraint for memory efficiency while achieving full-rank training (i.e., training with full-rank gradients of full-rank weights) to avoid inferior outcomes? In this paper, we propose a new plug-and-play training framework for LLMs called Fira, as the first attempt to achieve this goal. First, we observe an interesting phenomenon during LLM training: the scaling impact of adaptive optimizers (e.g., Adam) on the gradient norm remains similar from low-rank to full-rank training. Based on this observation, we propose a norm-based scaling method, which uses the scaling impact of the low-rank optimizer as a substitute for that of the original full-rank optimizer to enable full-rank training. In this way, we can preserve the low-rank constraint in the optimizer while achieving full-rank training for better performance. Moreover, we find that the gradient can rise suddenly during optimization, potentially causing loss spikes. To address this, we further propose a norm-growth limiter that smooths the gradient by regulating the relative increase of gradient norms. Extensive experiments on the pre-training and fine-tuning of LLMs show that Fira outperforms both LoRA and GaLore, achieving performance comparable to or even better than full-rank training.
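To make the two mechanisms concrete, here is a minimal NumPy sketch of one update step under stated assumptions; it is not the authors' implementation. It assumes a GaLore-style setup in which the gradient G (m×n) is projected with an m×r matrix P that has orthonormal columns, and Adam runs only on R = PᵀG. The factor phi = ‖ψ(R)‖ / ‖R‖ measures how strongly Adam rescales the gradient norm inside the subspace and is reused for the residual G − PR that the low-rank optimizer never sees; the limiter then caps the growth of that scaled residual's norm at a factor gamma per step. The names `LowRankAdam` and `fira_style_step`, the single matrix-level factor, and gamma = 1.01 are hypothetical choices for illustration.

```python
import numpy as np


class LowRankAdam:
    """Hypothetical Adam whose moments live in an r x n subspace (GaLore-style).

    Projection refresh and weight decay are omitted for brevity.
    """

    def __init__(self, shape, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.m = np.zeros(shape)
        self.v = np.zeros(shape)
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.t = 0

    def direction(self, r_grad):
        # Bias-corrected Adam direction psi(R), computed on the low-rank gradient only.
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * r_grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * r_grad**2
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        return m_hat / (np.sqrt(v_hat) + self.eps)


def fira_style_step(weight, grad, proj, opt, prev_norm, gamma=1.01):
    """One full-rank update with low-rank optimizer state (illustrative sketch).

    prev_norm is the norm of the previous step's correction term (None at step 1).
    Returns the updated weight and the correction norm to carry forward.
    """
    r_grad = proj.T @ grad                    # R = P^T G, shape r x n
    low_rank_dir = opt.direction(r_grad)      # psi(R)
    # Assumed norm-based scaling: Adam's effect on the norm, measured in the subspace.
    phi = np.linalg.norm(low_rank_dir) / (np.linalg.norm(r_grad) + 1e-12)
    residual = grad - proj @ r_grad           # part of G outside span(P)
    correction = phi * residual               # stand-in for full-rank Adam on the residual
    # Assumed norm-growth limiter: cap the relative increase of the correction norm.
    norm = np.linalg.norm(correction)
    if prev_norm is not None and norm > gamma * prev_norm:
        correction = correction * (gamma * prev_norm / norm)
        norm = gamma * prev_norm
    update = proj @ low_rank_dir + correction  # full-rank update direction
    return weight - opt.lr * update, norm


# Usage on toy data: random weights and gradient, top-r left singular vectors as P.
rng = np.random.default_rng(0)
m, n, r = 8, 6, 2
W = rng.standard_normal((m, n))
G = rng.standard_normal((m, n))
P = np.linalg.svd(G, full_matrices=False)[0][:, :r]
opt = LowRankAdam((r, n))
W1, c_norm = fira_style_step(W, G, P, opt, prev_norm=None)

# Second step with an artificially small previous norm so the limiter engages.
prev = 1e-3 * c_norm
W2, c2 = fira_style_step(W1, G, P, opt, prev_norm=prev)
```

The optimizer state stays r×n, yet the weight change has a nonzero component outside span(P), which is what "full-rank training under a low-rank constraint" means here. A single scalar factor per matrix keeps the extra state to one number (the previous correction norm); finer-grained factors, for example one per column, are a natural variant of the same idea.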