Avoiding spurious sharpness minimization broadens applicability of SAM

📅 2025-02-04
📈 Citations: 0
Influential: 0
🤖 AI Summary
Sharpness-Aware Minimization (SAM) often degrades performance on NLP tasks because its effect is dominated by implicit regularization of logit statistics rather than by improving the geometry of the function the network implements, introducing spurious sharpness minimization. Method: The authors distinguish function-level from logit-level sources of sharpness and introduce Functional-SAM, which regularizes curvature only through the statistics of the overall function implemented by the network, avoiding spurious minimization through logit manipulation; they further show that preconditioning the SAM perturbation also prevents spurious minimization and yields additional gains when combined with Functional-SAM. Contribution/Results: Functional-SAM is compatible with Chinchilla-style training and consistently outperforms AdamW and SAM baselines across model scales, including billion-parameter models, under both fixed-step and Chinchilla-style training budgets.

📝 Abstract
Curvature regularization techniques like Sharpness Aware Minimization (SAM) have shown great promise in improving generalization on vision tasks. However, we find that SAM performs poorly in domains like natural language processing (NLP), often degrading performance -- even with twice the compute budget. We investigate the discrepancy across domains and find that in the NLP setting, SAM is dominated by regularization of the logit statistics -- instead of improving the geometry of the function itself. We use this observation to develop an alternative algorithm we call Functional-SAM, which regularizes curvature only through modification of the statistics of the overall function implemented by the neural network, and avoids spurious minimization through logit manipulation. Furthermore, we argue that preconditioning the SAM perturbation also prevents spurious minimization, and when combined with Functional-SAM, it gives further improvements. Our proposed algorithms show improved performance over AdamW and SAM baselines when trained for an equal number of steps, in both fixed-length and Chinchilla-style training settings, at various model scales (including billion-parameter scale). On the whole, our work highlights the importance of more precise characterizations of sharpness in broadening the applicability of curvature regularization to large language models (LLMs).
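For reference, the baseline two-step SAM update that the paper modifies can be sketched as follows. This is the generic SAM of the original formulation (ascend to a worst-case point within an L2 ball, then descend with the gradient taken there), not the paper's Functional-SAM variant; the quadratic toy loss is purely illustrative.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One generic SAM update: perturb the weights toward the
    worst-case point within an L2 ball of radius rho, then take a
    descent step using the gradient at the perturbed point."""
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent perturbation
    g_sharp = grad_fn(w + eps)                   # gradient at perturbed weights
    return w - lr * g_sharp

# Illustrative quadratic loss L(w) = 0.5 * w^T A w (not from the paper).
A = np.diag([1.0, 10.0])
loss = lambda w: 0.5 * w @ A @ w
grad = lambda w: A @ w

w = np.array([1.0, 1.0])
w_next = sam_step(w, grad)
```

Functional-SAM, as described in the abstract, alters what this procedure regularizes: curvature is penalized only through the statistics of the overall function rather than through direct manipulation of the logits.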
Problem

Research questions and friction points this paper is trying to address.

SAM often degrades performance in NLP, even with twice the compute budget
SAM's effect in NLP is dominated by spurious minimization of logit statistics
Curvature regularization needs a more precise characterization of sharpness to transfer to LLMs
Innovation

Methods, ideas, or system contributions that make the work stand out.

Functional-SAM regularizes curvature only through the statistics of the overall function, avoiding logit manipulation
Preconditioning the SAM perturbation also prevents spurious sharpness minimization
Both techniques outperform AdamW and SAM baselines across model scales, up to billion-parameter models
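The preconditioned-perturbation idea can be sketched generically: rescale the ascent direction by an inverse diagonal preconditioner before normalizing, so that high-curvature coordinates do not dominate the perturbation. The Adam-style second-moment preconditioner below is an illustrative assumption, not necessarily the exact choice made in the paper.

```python
import numpy as np

def preconditioned_sam_perturbation(g, v, rho=0.05, eps=1e-8):
    """Ascent perturbation scaled by a diagonal preconditioner.
    g: gradient; v: running second-moment estimate (Adam-style).
    The diagonal Adam-style preconditioner is an assumption for
    illustration, not the paper's exact construction."""
    d = g / (np.sqrt(v) + eps)                    # preconditioned direction
    return rho * d / (np.linalg.norm(d) + 1e-12)  # rescale to radius rho

g = np.array([1.0, 10.0])
v = np.array([1.0, 100.0])  # larger second moment in the 2nd coordinate
pert = preconditioned_sam_perturbation(g, v)
```

With this choice, a coordinate whose gradient is large merely because its scale is large gets damped, equalizing the perturbation across coordinates rather than letting logit-scale directions dominate.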