🤖 AI Summary
Existing weight averaging methods (e.g., SWA, LAWA) rely on hand-crafted checkpoint selection strategies and extensive hyperparameter tuning, which limits their generality and robustness. This paper proposes SeWA, an adaptive weight averaging framework that recasts discrete checkpoint selection as an end-to-end continuous optimization problem. Each checkpoint receives a learnable selection probability, parameterized with Gumbel-Softmax, so that critical late-stage checkpoints are selected and averaged automatically. The authors derive, for the first time, stability-based generalization bounds for SeWA that are tighter than those of SGD under both convex and non-convex assumptions. Empirically, SeWA significantly outperforms SWA and LAWA in behavior cloning, image classification, and text classification while averaging only a few checkpoints. It improves accuracy, speeds up convergence, and greatly reduces manual intervention.
📝 Abstract
Weight averaging has become a standard technique for enhancing model performance. However, methods such as Stochastic Weight Averaging (SWA) and Latest Weight Averaging (LAWA) often require manually designed procedures to sample from the training trajectory, and their results depend heavily on hyperparameter tuning. To minimize human effort, this paper proposes a simple yet efficient algorithm called Selective Weight Averaging (SeWA), which adaptively selects checkpoints from the final stages of training for averaging. With SeWA, we show that only a few checkpoints are needed to achieve better generalization and faster convergence. Solving the discrete subset selection problem directly is inherently challenging, because the number of candidate subsets grows exponentially with the number of checkpoints. To address this, we transform it into a continuous probabilistic optimization problem and employ the Gumbel-Softmax estimator to learn the non-differentiable mask for each checkpoint. Further, we derive stability-based generalization bounds for SeWA that are sharper than those of SGD under both convex and non-convex assumptions. Finally, extensive experiments in various domains, including behavior cloning, image classification, and text classification, validate the effectiveness of our approach.
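The abstract does not spell out the exact mask parameterization, so the following is only a minimal NumPy sketch of the general idea under stated assumptions. Each checkpoint i gets a hypothetical logit s_i. The soft keep-mask is a binary Gumbel-Softmax (relaxed Bernoulli) sample m_i = sigmoid((s_i + g_i1 - g_i2) / tau), where g_i1 and g_i2 are Gumbel(0, 1) noise. The averaged weights are sum_i m_i w_i / sum_i m_i. A toy quadratic loss against a hypothetical target vector stands in for the real training loss. The function names `relaxed_bernoulli_mask`, `masked_average`, and `train_mask_logits` are illustrative and not taken from the paper.

```python
import numpy as np


def relaxed_bernoulli_mask(logits, tau, rng):
    """Binary Gumbel-Softmax (relaxed Bernoulli) sample, one value per checkpoint.

    A two-class Gumbel-Softmax over [logit, 0] reduces to
    sigmoid((logit + g1 - g2) / tau) with g1, g2 ~ Gumbel(0, 1).
    """
    g1 = rng.gumbel(size=logits.shape)
    g2 = rng.gumbel(size=logits.shape)
    return 1.0 / (1.0 + np.exp(-(logits + g1 - g2) / tau))


def masked_average(checkpoints, mask):
    """Average checkpoints (shape K x d) weighted by a soft mask (shape K)."""
    return (mask[:, None] * checkpoints).sum(axis=0) / mask.sum()


def train_mask_logits(checkpoints, target, steps=200, lr=0.5, tau=0.5, seed=0):
    """Learn per-checkpoint logits by gradient descent on the toy loss
    0.5 * ||w_bar - target||^2 (an assumed stand-in for the real loss),
    using hand-derived gradients instead of an autodiff framework."""
    rng = np.random.default_rng(seed)
    logits = np.zeros(len(checkpoints))
    for _ in range(steps):
        m = relaxed_bernoulli_mask(logits, tau, rng)
        s = m.sum()
        w_bar = masked_average(checkpoints, m)
        g_wbar = w_bar - target  # dL/dw_bar
        # Chain rule: dw_bar/dm_i = (w_i - w_bar) / s and
        # dm_i/dlogit_i = m_i * (1 - m_i) / tau.
        grad = (checkpoints - w_bar) @ g_wbar / s * m * (1.0 - m) / tau
        logits -= lr * grad
    return logits


if __name__ == "__main__":
    # Hypothetical toy: three late checkpoints at the target, one outlier.
    ckpts = np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [5.0, -3.0]])
    target = np.array([1.0, 1.0])
    logits = train_mask_logits(ckpts, target)
    print("keep probabilities:", 1.0 / (1.0 + np.exp(-logits)))
```

In this toy setup the checkpoints that coincide with the target end up with positive logits, while the outlier's logit becomes negative, so it is effectively dropped from the average. A real implementation would instead compute the model's loss on the averaged weights and backpropagate through an autodiff framework. For example, PyTorch's `torch.nn.functional.gumbel_softmax` provides the relaxed sample along with a straight-through `hard=True` option.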