AI Summary
This work addresses a problem in fine-tuning pretrained language models: implicit biases in the training data can form spurious correlations with the target task, undermining fairness and out-of-distribution generalization. The authors propose GRASP, a method that, for the first time, identifies the latent factors behind such spurious associations directly from LoRA fine-tuning weights, without supervision. Using gradient projection, GRASP prevents the model from acquiring new reliance on these factors while preserving any pretrained content along them, rather than removing the factors from residual-stream activations. In experiments, GRASP completely eliminates misaligned behavior after fine-tuning on insecure code and reduces it roughly fivefold after fine-tuning on bad medical advice. On a political-bias task, it cuts political-lean drift on unrelated topics by more than 50% while also improving financial-advice performance. It outperforms all baselines on all three tasks.
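The claim that a latent factor can be read off the weights of a naive LoRA fine-tune can be illustrated with a toy example. The sketch below is a minimal illustration under an assumption of our own: the spurious factor shows up as the dominant input-space direction of the LoRA update ΔW = B A, so the top right singular vector of ΔW recovers it. The function name `top_input_direction` and the planted-direction setup are hypothetical and are not the paper's identification procedure.

```python
import numpy as np

def top_input_direction(lora_A: np.ndarray, lora_B: np.ndarray) -> np.ndarray:
    """Hypothetical heuristic: return the unit input-space direction along
    which the LoRA update delta_W = B @ A acts most strongly (its top right
    singular vector). Not the paper's method; an illustrative assumption."""
    delta_w = lora_B @ lora_A  # shape (d_out, d_in)
    # Rows of vt are right singular vectors, sorted by descending singular value.
    _, _, vt = np.linalg.svd(delta_w, full_matrices=False)
    return vt[0]

# Toy check: plant one strong rank-1 "spurious" direction plus weak noise.
rng = np.random.default_rng(0)
d_in, d_out, rank = 64, 32, 4
v_true = rng.normal(size=d_in)
v_true /= np.linalg.norm(v_true)
u = rng.normal(size=d_out)
A = np.vstack([v_true, 0.01 * rng.normal(size=(rank - 1, d_in))])    # (rank, d_in)
B = np.column_stack([u, 0.01 * rng.normal(size=(d_out, rank - 1))])  # (d_out, rank)

v_hat = top_input_direction(A, B)
cos = abs(v_hat @ v_true)  # singular vectors are defined only up to sign
print(f"|cos(v_hat, v_true)| = {cos:.4f}")
```

The absolute cosine is used because an SVD returns singular vectors only up to sign. In a real model the factor need not be rank-1 or dominant, and that is why the paper states its identifiability result under assumptions on task complexity and the spurious correlation.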
Abstract
Fine-tuning a pretrained language model on a curated dataset can produce spurious correlations between the fine-tuning task and unintended latent factors -- such as misaligned personas or political slant -- that the curation procedure has entangled with the task. The model can latch onto these spurious correlations, leading to bias and reduced out-of-distribution generalisation. We prove that under reasonable assumptions on task complexity and the spurious correlation, such latent factors can be identified, without supervision, from the weights of a naive LoRA fine-tune. Existing approaches to removing bias, such as activation steering, remove identified factors from residual-stream activations, either at inference or during training. We argue, however, that the goal should be to remove the spurious correlation, not the latent factor itself, as the pretrained model may rely on the factor for genuine task signal. To enable this, we propose GRASP (GRadient projection of Associated Spurious Patterns), which prevents the model from acquiring new reliance on the identified latent factor while preserving any pretrained content along it. We validate GRASP on three fine-tuning tasks. The first two involve emergent misalignment, where fine-tuning on a narrow task -- in our case, writing insecure code and giving bad medical advice -- leads to misaligned responses on unrelated topics. Here our method completely removes misalignment in the insecure-code case and reduces it by ~5x in the bad-medical-advice case, beating all baselines in the trade-off between misalignment reduction and task preservation. The last is a novel political-bias experiment, where fine-tuning on right-skewed Reddit financial-advice data causes political-lean drift on unrelated topics. Here our method reduces drift by more than half while improving financial-task performance, again beating all baselines.
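The distinction between removing a latent factor and blocking new reliance on it can be made concrete with a gradient projection on a single weight matrix. The sketch below is minimal and rests on assumptions of our own. First, the factor is a unit direction v in the layer's input space. Second, "reliance" means how the layer responds to v, that is, W @ v. Third, the projection is applied to a full weight gradient rather than to LoRA factors. The function `project_out_input_direction` is hypothetical, and the paper's exact GRASP update (which side is projected, which layers, how LoRA is handled) may differ.

```python
import numpy as np

def project_out_input_direction(grad: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Hypothetical helper: remove the part of a weight gradient that would
    change how the layer reads input direction v.

    grad has shape (d_out, d_in). After projection, grad_proj @ v == 0, so
    a step W -= lr * grad_proj leaves W @ v unchanged."""
    v = v / np.linalg.norm(v)
    return grad - np.outer(grad @ v, v)

# Toy check: updates change W but never its pretrained response to v.
rng = np.random.default_rng(1)
W = rng.normal(size=(8, 16))        # stand-in for a pretrained weight
W_init = W.copy()
v = rng.normal(size=16)
v_unit = v / np.linalg.norm(v)
response_before = W @ v_unit        # pretrained content along v

lr = 0.1
for _ in range(100):
    grad = rng.normal(size=W.shape)  # stand-in for a real loss gradient
    W -= lr * project_out_input_direction(grad, v)

response_after = W @ v_unit
print(np.allclose(response_before, response_after))  # pretrained response preserved
```

Unlike activation steering, this leaves the factor in the activations and keeps whatever the pretrained weights already did with it. Only the fine-tuning updates are constrained, so the model cannot learn any new dependence on v.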