🤖 AI Summary
Existing weight averaging methods (e.g., SWA) rely on pre-specified weighting schemes, yield limited generalization gains, and scale poorly to large-scale training. To address these limitations, this paper proposes Trainable Weight Averaging (TWA), which formulates weight averaging as a differentiable optimization problem within a learned subspace, jointly learning optimal convex combination coefficients over candidate model weights. The authors introduce two complementary paradigms, TWA-t (training-driven) and TWA-v (validation-driven), and integrate a distributed framework with low-bit compression of the projection matrix to enable memory-efficient, parallelizable training. Crucially, TWA casts weight averaging as an end-to-end learnable process, eliminating reliance on handcrafted weighting heuristics. Experiments on CIFAR-10/100 and ImageNet demonstrate that TWA reduces training time by over 40% and 30%, respectively, without accuracy loss; during fine-tuning, it consistently outperforms SWA in generalization. The implementation is publicly available.
📝 Abstract
Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). Existing approaches such as stochastic weight averaging (SWA) rely on pre-set weighting schemes, which can be suboptimal when handling diverse weights. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns optimal weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression for the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization by weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
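To make the core idea concrete, here is a minimal sketch of learning convex combination coefficients over candidate weight vectors. It is not the paper's implementation: the function name, the softmax parameterization of the coefficients, and the toy quadratic loss (a stand-in for the training or validation loss of TWA-t / TWA-v) are all illustrative assumptions; the real method operates on flattened DNN weights with a compressed projection matrix.

```python
import numpy as np

def trainable_weight_average(candidates, loss_grad, steps=200, lr=0.5):
    """Learn convex combination coefficients over candidate weight vectors.

    candidates: (K, D) array of K flattened candidate model weights.
    loss_grad:  function mapping a weight vector w of shape (D,) to dL/dw.
    Returns (averaged_weights, coefficients). Hypothetical sketch, not the
    paper's implementation.
    """
    K = candidates.shape[0]
    theta = np.zeros(K)  # logits; softmax keeps coefficients convex
    for _ in range(steps):
        alpha = np.exp(theta) / np.exp(theta).sum()
        w = alpha @ candidates            # averaged weights in the spanned subspace
        g_w = loss_grad(w)                # gradient w.r.t. the averaged weights
        g_alpha = candidates @ g_w        # chain rule: dL/dalpha_i = w_i . dL/dw
        # backprop through softmax: dL/dtheta_j = alpha_j (g_alpha_j - alpha . g_alpha)
        g_theta = alpha * (g_alpha - alpha @ g_alpha)
        theta -= lr * g_theta
    alpha = np.exp(theta) / np.exp(theta).sum()
    return alpha @ candidates, alpha

# Toy example: checkpoints scattered around an (unknown to the optimizer) target.
rng = np.random.default_rng(0)
target = rng.normal(size=8)
cands = target + 0.3 * rng.normal(size=(4, 8))   # noisy "checkpoints"
w_twa, alpha = trainable_weight_average(cands, lambda w: 2 * (w - target))
```

Starting from uniform coefficients (theta = 0) reproduces plain SWA-style averaging; gradient steps then reweight the checkpoints, so the learned average can only match or improve on the uniform one under this loss.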