Trainable Weight Averaging: Accelerating Training and Improving Generalization

📅 2022-05-26
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
Existing weight averaging methods (e.g., SWA) rely on pre-specified weighting schemes, yield limited generalization gains, and scale poorly to large-scale training. To address these limitations, this paper proposes Trainable Weight Averaging (TWA), which formulates weight averaging as a differentiable optimization problem within a learned subspace, jointly learning optimal convex combination coefficients over candidate model weights. The authors introduce two complementary paradigms, TWA-t (training-driven) and TWA-v (validation-driven), and integrate a distributed framework with low-bit projection matrix compression to enable memory-efficient, parallelizable training. Crucially, TWA casts weight averaging as an end-to-end learnable process, eliminating reliance on handcrafted weighting heuristics. Experiments demonstrate that TWA reduces training time by over 40% on CIFAR-10/100 and 30% on ImageNet without accuracy loss; during fine-tuning, it consistently outperforms SWA in generalization. The implementation is publicly available.
๐Ÿ“ Abstract
Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). While existing approaches like stochastic weight averaging (SWA) rely on pre-set weighting schemes, they can be suboptimal when handling diverse weights. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns optimal weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression for the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization by weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
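The abstract describes TWA's core idea: learn optimal combination coefficients over a set of candidate weights by gradient descent, rather than fixing them in advance as SWA does. Below is a minimal NumPy sketch of that coefficient-learning step, assuming a softmax parameterization to keep the combination convex; the paper's exact parameterization, subspace projection, and optimizer are not specified here, so `twa_t`, the toy quadratic loss, and all names are illustrative only. TWA-v would run the same loop but draw gradients from validation data.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def twa_t(candidates, loss_grad, steps=500, lr=0.3):
    """Learn convex combination coefficients over candidate weights.

    candidates : (K, D) array, one flattened weight vector per row.
    loss_grad  : function w -> dL/dw (here: from training data, i.e. TWA-t).
    Returns (averaged weight vector, coefficients).
    """
    K = candidates.shape[0]
    z = np.zeros(K)                     # logits; softmax keeps alpha on the simplex
    for _ in range(steps):
        alpha = softmax(z)
        w = alpha @ candidates          # current averaged weights
        g_w = loss_grad(w)              # gradient w.r.t. the averaged weights
        g_alpha = candidates @ g_w      # chain rule: dL/dalpha_i = c_i . dL/dw
        # backprop through softmax: dL/dz_j = alpha_j * (g_alpha_j - alpha . g_alpha)
        g_z = alpha * (g_alpha - alpha @ g_alpha)
        z -= lr * g_z
    alpha = softmax(z)
    return alpha @ candidates, alpha

# Toy stand-in for a training loss: L(w) = 0.5 * ||w - target||^2,
# whose minimizer lies inside the convex hull of the candidates.
target = np.array([1.0, 2.0])
cands = np.array([[0.0, 0.0], [2.0, 4.0], [1.0, 1.0]])
w, alpha = twa_t(cands, lambda w: w - target)
```

The learned `alpha` stays nonnegative and sums to one by construction, so `w` is always a valid weighted average of the candidates; the uniform average `cands.mean(0)` would miss `target`, while the learned combination approaches it.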
Problem

Research questions and friction points this paper is trying to address.

Pre-set weighting schemes (as in SWA) can be suboptimal for diverse candidate weights
Training deep networks to convergence is computationally expensive
Existing averaging methods yield limited generalization gains and scale poorly
Innovation

Methods, ideas, or system contributions that make the work stand out.

Trainable Weight Averaging learns optimal weighting coefficients within a subspace spanned by candidate weights
Distributed framework combines parallel computation with low-bit compression of the projection matrix
TWA-v learns coefficients on validation data, averaging more effectively than training-driven TWA-t
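The second bullet refers to low-bit compression of the projection matrix to cut memory cost. The paper's exact scheme is not given in this summary, so the sketch below shows one generic possibility: per-row symmetric 8-bit quantization, storing the matrix as `int8` plus one float scale per row (a 4x reduction from float32). The function names `quantize_rows`/`dequantize` are illustrative, not from the paper.

```python
import numpy as np

def quantize_rows(P):
    """Per-row symmetric quantization of a projection matrix to int8.

    Each row is scaled so its largest-magnitude entry maps to +/-127,
    then rounded; the float scale per row is kept for dequantization.
    """
    qmax = 127                               # signed 8-bit range
    scale = np.abs(P).max(axis=1, keepdims=True) / qmax
    scale[scale == 0] = 1.0                  # guard all-zero rows
    q = np.round(P / scale).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover an approximate float matrix from int8 values and row scales."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
P = rng.standard_normal((4, 1000)).astype(np.float32)  # toy projection matrix
q, s = quantize_rows(P)
P_hat = dequantize(q, s)
```

Rounding bounds the per-entry reconstruction error by half of the row's scale, so the compressed matrix stays close to the original while using a quarter of the storage.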