🤖 AI Summary
This work addresses the susceptibility of pretrained language models to catastrophic forgetting during fine-tuning or quantization, which often significantly degrades their original capabilities. The authors bring flat-minima optimization into the pretraining phase, using Sharpness-Aware Minimization (SAM), large learning rates, and a shortened learning rate annealing period to steer convergence toward flatter minima. Across models from 20M to 150M parameters, these interventions improve downstream performance after post-training on five common datasets, with up to 80% less forgetting. Applied as a short mid-training phase to an existing OLMo-2-1B checkpoint, SAM reduces forgetting by 31% after MetaMath fine-tuning and by 40% after 4-bit quantization.
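The "shortened annealing" knob can be made concrete with a warmup-stable-decay (WSD) schedule, in which the learning rate is held at a (large) peak and then decayed over a final fraction of training. The summary does not specify the schedule shape, so this sketch assumes linear warmup and linear decay; the function `wsd_lr`, its `decay_frac` parameter, and the value `peak_lr=3e-3` are hypothetical choices for illustration, not the authors' configuration. Shortening the annealing period corresponds to lowering `decay_frac`.

```python
def wsd_lr(
    step: int,
    total_steps: int,
    peak_lr: float,
    warmup_steps: int,
    decay_frac: float,
    final_lr: float = 0.0,
) -> float:
    """Hypothetical warmup-stable-decay schedule (assumed shape, not the paper's).

    decay_frac is the fraction of total_steps spent annealing; a smaller
    value means a shorter annealing period and more steps at peak_lr.
    """
    if not 0.0 < decay_frac <= 1.0:
        raise ValueError("decay_frac must be in (0, 1]")
    decay_steps = max(1, round(decay_frac * total_steps))
    decay_start = total_steps - decay_steps
    if step < warmup_steps:
        # Linear warmup from near zero up to the peak learning rate.
        return peak_lr * (step + 1) / warmup_steps
    if step < decay_start:
        # Stable phase: hold the large peak learning rate.
        return peak_lr
    # Annealing phase: linear decay from peak_lr to final_lr over decay_steps.
    progress = min(1.0, (step - decay_start + 1) / decay_steps)
    return peak_lr + (final_lr - peak_lr) * progress


if __name__ == "__main__":
    # Compare a short (10%) and a long (50%) annealing period.
    for frac in (0.1, 0.5):
        lrs = [wsd_lr(s, 1000, 3e-3, 100, frac) for s in range(1000)]
        print(frac, lrs[600], lrs[-1])
```

With `total_steps=1000` and `warmup_steps=100`, `decay_frac=0.1` holds the peak rate through step 899, whereas `decay_frac=0.5` begins annealing at step 500; both reach `final_lr` on the last step.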
📝 Abstract
Pretraining optimizers are tuned to produce the strongest possible base model, on the assumption that a stronger starting point yields a stronger model after subsequent changes like post-training and quantization. This overlooks the geometry of the base model, which controls how much of its capabilities survive subsequent parameter updates. We study three pretraining interventions that bias optimization toward flatter minima: Sharpness-Aware Minimization (SAM), large learning rates, and shortened learning rate annealing periods. Across model sizes ranging from 20M to 150M parameters, we find that these interventions consistently improve downstream performance after post-training on five common datasets with up to 80% less forgetting. These principles hold at scale: a short SAM mid-training phase applied to an existing OLMo-2-1B checkpoint reduces forgetting by 31% after MetaMath post-training and by 40% after 4-bit quantization.
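SAM replaces each gradient step with two passes: it first perturbs the weights toward higher loss within an L2 ball of radius `rho`, then applies the gradient measured at that perturbed point to the original weights, which penalizes sharp regions of the loss. The following is a minimal pure-Python sketch of that standard update (Foret et al., 2021) on a toy loss, not the authors' training code; the helper name `sam_step` and the `lr` and `rho` values are illustrative assumptions.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_step(
    w: Vector,
    grad_fn: Callable[[Vector], Vector],
    lr: float,
    rho: float = 0.05,
) -> Vector:
    """One standard SAM step (hypothetical helper, not the paper's implementation)."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        # Stationary point: there is no ascent direction, so leave w unchanged.
        return list(w)
    # Ascent step: first-order worst-case perturbation within an L2 ball of radius rho.
    eps = [rho * gi / norm for gi in g]
    w_adv = [wi + ei for wi, ei in zip(w, eps)]
    # Descent step: gradient taken at the perturbed point, applied to the original weights.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


if __name__ == "__main__":
    # Toy loss L(w) = w^2 with gradient 2w; SAM drives w toward the minimum at 0.
    w = [1.0]
    for _ in range(200):
        w = sam_step(w, lambda v: [2.0 * v[0]], lr=0.1, rho=0.05)
    print(w)
```

On this toy loss the iterate settles into a small oscillation around the minimum rather than reaching it exactly, because the fixed-radius perturbation does not vanish at the optimum. In practice the same two-pass step is applied to mini-batch gradients of a full network, roughly doubling the cost per step, which is one reason a short SAM mid-training phase is an attractive alternative to full SAM pretraining.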