Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning

๐Ÿ“… 2024-10-20
๐Ÿ›๏ธ arXiv.org
๐Ÿ“ˆ Citations: 1
โœจ Influential: 0
๐Ÿ“„ PDF
๐Ÿค– AI Summary
Sequential supervised fine-tuning (SFT) followed by preference learning (e.g., DPO or RLHF) in large language model post-training induces catastrophic forgetting, degrading SFT task performance while optimizing for preference alignment. Method: This work theoretically establishes the suboptimality of sequential training and proposes the first jointly optimized framework with provable convergence guarantees. It unifies SFT and DPO objectives via a scalable multi-objective loss function and performs joint gradient updatesโ€”without additional computational overhead. Contribution/Results: The method enables effective knowledge fusion across both stages. Experiments demonstrate a 23% improvement in SFT task retention and a 9.7% gain in preference alignment accuracy over sequential baselines, while maintaining comparable computational cost.

Technology Category

Application Category

๐Ÿ“ Abstract
Post-training of pre-trained LLMs, which typically consists of the supervised fine-tuning (SFT) stage and the preference learning (RLHF or DPO) stage, is crucial to effective and safe LLM applications. The widely adopted approach in post-training popular open-source LLMs is to sequentially perform SFT and RLHF/DPO. However, sequential training is sub-optimal in terms of SFT and RLHF/DPO trade-off: the LLM gradually forgets about the first stage's training when undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework with theoretical convergence guarantees and empirically outperforms sequential post-training framework, while having similar computational cost. Our code is available at https://github.com/heshandevaka/XRIGHT.
Problem

Research questions and friction points this paper is trying to address.

Mitigate forgetting in LLM post-training
Optimize SFT and RLHF/DPO trade-off
Propose joint post-training framework
Innovation

Methods, ideas, or system contributions that make the work stand out.

Joint post-training framework
Mitigates forgetting in LLMs
Theoretical convergence guarantees
๐Ÿ”Ž Similar Papers