🤖 AI Summary
To mitigate catastrophic forgetting when fine-tuning large pre-trained models, a problem made harder when the original pre-training data and recipe are inaccessible, this paper proposes a dynamic sample-weighting method that relies solely on the pre-trained model's forward loss. By upweighting easy samples (those on which the pre-trained model's loss is low), the approach limits drift from the pre-trained parameters and addresses forgetting directly in the sample space rather than in parameter or gradient space, making it orthogonal and complementary to existing methods. A theoretical analysis in a linear setting shows that the strategy stalls learning in a certain subspace, inhibiting overfitting to the target task. Experimentally, when fine-tuning Gemma 2 2B on MetaMathQA, the method incurs only a 0.8% accuracy drop on GSM8K relative to standard fine-tuning while preserving 5.4% more accuracy on the pre-training tasks. The implementation is publicly available.
📝 Abstract
Fine-tuning a pre-trained model on a downstream task often degrades its original capabilities, a phenomenon known as "catastrophic forgetting". This is especially problematic when one does not have access to the data and recipe used to develop the pre-trained model. Under this constraint, most existing methods for mitigating forgetting are inapplicable. To address this challenge, we propose a sample weighting scheme for the fine-tuning data based solely on the pre-trained model's losses. Specifically, we upweight the easy samples on which the pre-trained model's loss is low, and vice versa, to limit the drift from the pre-trained model. Our approach is orthogonal and yet complementary to existing methods; while such methods mostly operate on parameter or gradient space, we concentrate on the sample space. We theoretically analyze the impact of fine-tuning with our method in a linear setting, showing that it stalls learning in a certain subspace, which inhibits overfitting to the target task. We empirically demonstrate the efficacy of our method on both language and vision tasks. As an example, when fine-tuning Gemma 2 2B on MetaMathQA, our method results in only a $0.8\%$ drop in accuracy on GSM8K (another math dataset) compared to standard fine-tuning, while preserving $5.4\%$ more accuracy on the pre-training datasets. Our code is publicly available at https://github.com/sanyalsunny111/FLOW_finetuning.
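To make the weighting idea concrete, here is a minimal sketch of mapping per-sample losses under the frozen pre-trained model to fine-tuning weights, where low loss (an easy sample) yields a high weight. The exponential/softmax form and the `temperature` parameter are illustrative assumptions, not the paper's exact formulation:

```python
import math

def sample_weights(pretrained_losses, temperature=1.0):
    """Convert per-sample losses from the frozen pre-trained model into
    normalized fine-tuning weights. Lower loss (easier sample) -> larger
    weight, which biases fine-tuning toward data the pre-trained model
    already handles well and thus limits drift. The softmax over negative
    losses used here is one simple, assumed choice of weighting function."""
    scores = [math.exp(-loss / temperature) for loss in pretrained_losses]
    total = sum(scores)
    return [s / total for s in scores]

# Easy samples (low pre-trained loss) receive the largest weights.
weights = sample_weights([0.1, 1.0, 3.0])
```

In a training loop these weights would multiply each sample's loss before averaging; the `temperature` knob (an assumption of this sketch) controls how sharply easy samples dominate.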