🤖 AI Summary
This work addresses the vulnerability of models to spurious correlations, which degrades out-of-distribution generalization under distribution shift. To mitigate this issue, the authors propose Hybrid Prompt Arithmetic (HyPA), a parameter-efficient approach that combines task-specific soft prompts with linearized confounder prompts through prompt arithmetic, explicitly reducing the model’s reliance on confounding variables. Evaluated across multiple out-of-distribution generalization benchmarks, HyPA achieves a better robustness-performance trade-off than prompt-arithmetic baselines. Representation analysis provides evidence consistent with the method either reducing the influence of confounder signals on predictions or suppressing them in hidden representations. These findings suggest a new direction for causally robust learning within the prompt tuning framework.
📝 Abstract
In classification tasks, models may rely on confounding variables to achieve strong in-distribution performance, capturing spurious features that fail under distribution shift. This shortcut behavior leads to substantial degradation in out-of-distribution settings. Task arithmetic offers a potential solution by removing unwanted signals via subtraction of secondary model updates, but it typically requires full fine-tuning, which is computationally expensive. Prompt tuning provides a parameter-efficient alternative by adapting models through a small set of trainable virtual tokens. Applying task arithmetic to the resulting prompts is an appealing substitute for operating on entire models, but the extent to which it can limit reliance on spurious features remains to be established. In this work, we study whether composing soft prompts through task arithmetic improves robustness to confounding shifts. We propose Hybrid Prompt Arithmetic (HyPA), which combines task prompts with linearized confounder prompts to counteract spurious correlations. Across multiple benchmarks, HyPA consistently improves the robustness-performance trade-off relative to prompt-arithmetic baselines under distribution shift. We further analyze how HyPA affects hidden representations and find evidence consistent with it mitigating confounding either by reducing the influence of confounder signals on predictions or by suppressing them in the representation. These results indicate that HyPA is a promising, parameter-efficient approach for improving robustness under confounding shifts in the evaluated setting.
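To make the idea of arithmetic on soft prompts concrete, the following is a minimal sketch of generic prompt arithmetic, not the authors' HyPA implementation. It assumes each soft prompt is a (num_tokens, dim) array that was tuned from a shared initialization, and it subtracts a scaled confounder update from the task update. The function names, the `scale` parameter, and the shared-initialization setup are all assumptions made for illustration; the linearization of the confounder prompt mentioned in the abstract is not modeled here.

```python
import numpy as np

def prompt_update(tuned_prompt, init_prompt):
    """Update vector of a soft prompt: tuned embeddings minus the shared initialization."""
    return tuned_prompt - init_prompt

def compose_prompts(init_prompt, task_prompt, confounder_prompt, scale=1.0):
    """Add the task update and subtract a scaled confounder update (hypothetical helper).

    `scale` is an illustrative hyperparameter controlling how strongly the
    confounder direction is removed; scale=0 returns the task prompt unchanged.
    """
    tau_task = prompt_update(task_prompt, init_prompt)
    tau_conf = prompt_update(confounder_prompt, init_prompt)
    return init_prompt + tau_task - scale * tau_conf

# Toy data standing in for prompts produced by prompt tuning.
rng = np.random.default_rng(0)
num_tokens, dim = 4, 8
init = rng.normal(size=(num_tokens, dim))
task = init + rng.normal(scale=0.1, size=(num_tokens, dim))        # tuned on the target label
confounder = init + rng.normal(scale=0.1, size=(num_tokens, dim))  # tuned to predict the confounder

composed = compose_prompts(init, task, confounder, scale=0.5)
```

In this sketch the composed prompt would be prepended to the frozen model's input embeddings in place of the task prompt; only the small prompt arrays are manipulated, which is what makes the arithmetic parameter-efficient compared with subtracting full model updates.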