🤖 AI Summary
This work addresses the degradation of safety alignment in large language models (LLMs) during task-specific fine-tuning, which it attributes to catastrophic forgetting, and formulates safety preservation as a continual learning (CL) problem. In a fine-tuning-as-a-service setting, it adapts and compares existing CL methods from three paradigms: regularization, memory replay, and model merging. The methods are evaluated on both benign and poisoned user data. Experiments across three LLM families (LLaMA2, Mistral, and Gemma) and three downstream tasks (GSM8K, SST-2, and Code) show that DER (Dark Experience Replay), a memory-replay method, reduces attack success rates by 37% on average while incurring less than 1.2% degradation in task performance. DER outperforms standard fine-tuning, the other CL methods, and existing safety-preserving fine-tuning baselines, lowering attack success rates while maintaining task utility.
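In the CL literature, Dark Experience Replay keeps a buffer of past inputs together with the logits the model produced for them, and while training on new data it adds a penalty on how far the current logits drift from those stored logits. The sketch below shows that per-example objective in plain Python. It assumes (this is not stated in the summary) that the buffer holds safety-alignment examples whose logits were recorded from the aligned model before fine-tuning; the function names and the default `alpha` are hypothetical, and a real implementation would operate on batched tensors.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Softmax cross-entropy for one example, using a numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def logit_mse(current: Sequence[float], stored: Sequence[float]) -> float:
    """Squared error between current and stored logits, averaged over the logit dimension."""
    if len(current) != len(stored):
        raise ValueError("logit vectors must have the same length")
    return sum((c - s) ** 2 for c, s in zip(current, stored)) / len(current)


def der_loss(
    task_logits: Sequence[float],
    task_label: int,
    replay_logits: Sequence[float],
    stored_logits: Sequence[float],
    alpha: float = 0.5,  # hypothetical default weight for the replay term
) -> float:
    """Task loss on the user's example plus a penalty keeping replayed
    (assumed safety) examples close to the logits stored before fine-tuning."""
    return cross_entropy(task_logits, task_label) + alpha * logit_mse(
        replay_logits, stored_logits
    )


# Usage: the replay term is zero only when the model still reproduces the stored logits.
loss = der_loss([2.0, 0.0], 0, [1.0, 1.0], [1.0, 3.0], alpha=0.5)
```

Here the task term is log(1 + e^-2) and the replay term is 0.5 × 2.0 = 1.0. Matching logits rather than hard labels preserves the full output distribution of the aligned model on the replayed examples, which is the design choice that distinguishes DER from plain experience replay.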
📝 Abstract
The safety alignment of large language models (LLMs) is becoming increasingly important as these models become more widely accessible. In this paper, we study the safety degradation that comes with adapting LLMs to new tasks. We attribute this safety compromise to catastrophic forgetting and frame the problem of preserving safety during fine-tuning as a continual learning (CL) problem. We consider the fine-tuning-as-a-service setup, in which the user uploads their data to a service provider to obtain a customized model that excels on the user's selected task. We adapt several CL approaches from the literature and systematically evaluate their ability to mitigate safety degradation. These include regularization-based, memory-based, and model merging approaches. We consider two scenarios: (1) benign user data and (2) poisoned user data. Our results demonstrate that CL approaches consistently achieve lower attack success rates than standard fine-tuning. Among these, DER (Dark Experience Replay) outperforms both the other CL methods and existing safety-preserving baselines while maintaining task utility. These findings generalize across three downstream tasks (GSM8K, SST-2, Code) and three model families (LLaMA2-7B, Mistral-7B, Gemma-2B), establishing CL as a practical solution for preserving safety.
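The abstract does not say which model merging scheme is evaluated. As an assumption, the simplest one, parameter-wise linear interpolation between the safety-aligned weights and the fine-tuned weights, can be sketched as follows. Parameters are represented as flat lists of floats keyed by name purely for illustration; `merge_linear` and `lam` are hypothetical names, not the paper's API.

```python
from typing import Dict, List

# Hypothetical representation: parameter name -> flattened weights.
Weights = Dict[str, List[float]]


def merge_linear(aligned: Weights, finetuned: Weights, lam: float) -> Weights:
    """Return (1 - lam) * aligned + lam * finetuned, parameter by parameter.

    lam = 0 recovers the safety-aligned model; lam = 1 recovers the fine-tuned one.
    """
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must be in [0, 1]")
    if aligned.keys() != finetuned.keys():
        raise ValueError("models must have the same parameter names")
    merged: Weights = {}
    for name, w_aligned in aligned.items():
        w_ft = finetuned[name]
        if len(w_aligned) != len(w_ft):
            raise ValueError(f"shape mismatch for parameter {name!r}")
        merged[name] = [(1.0 - lam) * a + lam * f for a, f in zip(w_aligned, w_ft)]
    return merged


# Usage: keep 75% of the aligned weights and 25% of the fine-tuned weights.
merged = merge_linear({"w": [0.0, 2.0]}, {"w": [1.0, 4.0]}, lam=0.25)
# merged == {"w": [0.25, 2.5]}
```

The single coefficient `lam` makes the safety/utility trade-off explicit: moving it toward 0 pulls the model back toward its aligned weights, typically at some cost in task performance, so in practice it would be tuned on held-out task and safety evaluations.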