🤖 AI Summary
Large language models (LLMs) often learn non-generalizable mechanisms—such as exploiting dataset biases—on specific training data, thereby degrading cross-task transferability. To address this, we propose an attribution-driven neuron pruning framework: first, Integrated Gradients is applied to identify “harmful neurons” that disproportionately influence high-confidence predictions while undermining generalization; second, these neurons are selectively pruned and the model is fine-tuned, compelling it to rely on more robust and task-agnostic representations. Our method requires no additional annotations or architectural modifications. Evaluated on multiple-choice benchmarks, it significantly improves both transfer performance and out-of-distribution robustness—outperforming standard fine-tuning and state-of-the-art pruning baselines. By grounding pruning decisions in interpretable feature attribution, our approach establishes a principled, efficient, and explainable paradigm for enhancing LLM generalization.
📝 Abstract
Large language models (LLMs) often develop learned mechanisms specialized to specific datasets, such as reliance on domain-specific correlations, which yield high-confidence predictions without generalizable reasoning. While beneficial in one setting, these dataset-specific mechanisms typically degrade performance when models encounter novel tasks or distributions. In this work, we introduce a fine-tuning approach designed to enhance generalization by identifying and pruning neurons associated with dataset-specific mechanisms in transformer-based LLMs. Our method employs Integrated Gradients to quantify each neuron's influence on high-confidence predictions, pinpointing those that disproportionately contribute to dataset-specific performance without supporting robust, transferable reasoning. Selectively pruning these neurons compels the model to rely on generalizable representations instead. Evaluated across multiple-choice benchmarks, our pruning-based fine-tuning significantly improves transfer performance, surpassing prior (non-pruning) adaptation methods.