SeWA: Selective Weight Average via Probabilistic Masking

📅 2025-02-14
📈 Citations: 0 · Influential: 0
🤖 AI Summary
Existing weight averaging methods (e.g., SWA, LAWA) rely on hand-crafted checkpoint selection strategies and extensive hyperparameter tuning, which limits their generalization and robustness. This paper proposes SeWA, an adaptive weight averaging framework that recasts discrete checkpoint selection as an end-to-end continuous optimization problem: a learnable probability mask, parameterized with the Gumbel-Softmax estimator, automatically selects and weights critical late-stage checkpoints. The authors derive the first stability-based generalization bounds for SeWA, which are tighter than those of SGD. Empirically, SeWA outperforms SWA and LAWA on behavior cloning, image classification, and text classification while using only a few checkpoints, improving accuracy, accelerating convergence, and drastically reducing human intervention.
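
For reference, here is a minimal sketch of the relaxation the summary describes, written in our own notation (the paper's exact parameterization may differ). Each checkpoint w_i gets a learnable selection logit \ell_i, and a soft binary mask is drawn via the binary Gumbel-Softmax ("concrete") trick:

    m_i = \sigma\big( (\ell_i + g_i - g_i') / \tau \big), \qquad g_i, g_i' \sim \mathrm{Gumbel}(0,1), \qquad \bar{w} = \frac{\sum_i m_i w_i}{\sum_i m_i}

Here \sigma is the sigmoid and \tau a temperature: as \tau \to 0 the mask m_i approaches a hard 0/1 selection, while for \tau > 0 it stays differentiable, so the selection logits can be trained by gradient descent against a validation objective.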

📝 Abstract
Weight averaging has become a standard technique for enhancing model performance. However, methods such as Stochastic Weight Averaging (SWA) and Latest Weight Averaging (LAWA) often require manually designed procedures to sample from the training trajectory, and the results depend heavily on hyperparameter tuning. To minimize human effort, this paper proposes a simple yet efficient algorithm called Selective Weight Averaging (SeWA), which adaptively selects checkpoints during the final stages of training for averaging. Based on SeWA, we show that only a few points are needed to achieve better generalization and faster convergence. Solving the discrete subset-selection problem directly is inherently challenging, so we transform it into a continuous probabilistic optimization framework and employ the Gumbel-Softmax estimator to learn the otherwise non-differentiable selection mask for each checkpoint. Furthermore, we derive SeWA's stability-based generalization bounds, which are sharper than those of SGD under both convex and non-convex assumptions. Finally, extensive experiments across various domains, including behavior cloning, image classification, and text classification, validate the effectiveness of our approach.
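
To make the mechanism concrete, below is a minimal self-contained PyTorch sketch of the idea the abstract describes: one learnable logit per checkpoint, relaxed to a soft mask with the binary Gumbel-Softmax trick and trained end-to-end against a differentiable objective. Everything here (the function names gumbel_sigmoid and sewa_average, the quadratic stand-in for a validation loss, the temperature and learning rate) is our illustrative assumption, not the paper's implementation.

    import torch

    def gumbel_sigmoid(logits, tau=0.5):
        # Binary Gumbel-Softmax ("binary concrete") relaxation of a Bernoulli mask.
        # The difference of two Gumbel(0,1) samples is Logistic(0,1), sampled here
        # by inverse-CDF from a uniform draw.
        u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
        noise = torch.log(u) - torch.log1p(-u)
        return torch.sigmoid((logits + noise) / tau)

    def sewa_average(checkpoints, logits, tau=0.5):
        # checkpoints: (n_ckpt, n_params) tensor of flattened late-stage weights.
        # logits: one learnable selection score per checkpoint.
        mask = gumbel_sigmoid(logits, tau)  # soft 0/1 selection, differentiable
        return (mask.unsqueeze(1) * checkpoints).sum(0) / mask.sum().clamp_min(1e-8)

    # Toy run: learn which of 8 synthetic checkpoints to average so that the
    # averaged weights minimize a quadratic stand-in for a validation loss.
    torch.manual_seed(0)
    checkpoints = torch.randn(8, 16)
    target = torch.zeros(16)                     # hypothetical "ideal" weights
    logits = torch.zeros(8, requires_grad=True)  # start near P(select) = 0.5
    opt = torch.optim.Adam([logits], lr=0.1)
    for step in range(200):
        opt.zero_grad()
        loss = ((sewa_average(checkpoints, logits) - target) ** 2).mean()
        loss.backward()
        opt.step()
    print(torch.sigmoid(logits).detach())        # learned selection probabilities

In the paper, the differentiable objective would be a loss on held-out data and the checkpoints would come from the tail of the training trajectory; the sketch only shows why the Gumbel-Softmax relaxation makes the discrete selection trainable by gradient descent.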
Problem

Research questions and friction points this paper is trying to address.

Hand-crafted checkpoint-sampling schedules in SWA and LAWA limit robustness.
Averaging results depend heavily on hyperparameter tuning, demanding manual effort.
Checkpoint subset selection is discrete and non-differentiable, blocking end-to-end optimization.
Innovation

Methods, ideas, or system contributions that make the work stand out.

Selective Weight Averaging (SeWA): checkpoint selection recast as continuous probabilistic optimization
Gumbel-Softmax estimator to learn a differentiable relaxation of the selection mask
Stability-based generalization bounds sharper than SGD's, under convex and non-convex assumptions
👥 Authors
Peng Wang (School of Mathematics and Statistics, Huazhong University of Science and Technology, Wuhan, China)
Shengchao Hu (Shanghai Jiao Tong University)
Zerui Tao (RIKEN Center for Advanced Intelligence Project, Tokyo, Japan)
Guoxia Wang (Baidu Inc., Beijing, China)
Dianhai Yu (Baidu)
Li Shen (School of Cyber Science & Technology, Shenzhen Campus of Sun Yat-sen University, Shenzhen, China)
Quan Zheng (Institute of Software, Chinese Academy of Sciences)
Dacheng Tao (Nanyang Technological University)