🤖 AI Summary
Existing sparse training methods improve inference efficiency, but the most accurate ones still materialize dense weights or compute dense gradients during training, which compromises computational efficiency and can cost accuracy.
Method: We propose a fully sparse training framework that eliminates all dense intermediate representations. It employs guided stochastic exploration to dynamically grow critical connections, integrates gradient-driven structural search, and performs backpropagation without dense operations, achieving high sparsity (>95%) and high accuracy simultaneously in both training and inference.
Contribution/Results: Our method is the first to achieve time complexity linear in model width, significantly reducing memory footprint and energy consumption. Evaluated on CIFAR-10/100 and ImageNet across ResNet, VGG, and ViT architectures, it surpasses state-of-the-art sparse training methods in accuracy, accelerates training by up to 2.3×, and reduces GPU memory usage by up to 4.1×.
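The prune-and-grow dynamic described above can be illustrated with a small sketch. This is not the paper's implementation: the function name `prune_and_grow`, the candidate-sampling scheme, and all parameters are assumptions. It shows the general idea of guided stochastic exploration, which we take here to mean sampling a random subset of inactive connections and growing the ones with the largest gradient magnitude, so that no dense gradient over all inactive weights is ever needed.

```python
import numpy as np

def prune_and_grow(weights, mask, grads, drop_frac=0.1, n_candidates=64, rng=None):
    """Hedged sketch of one prune-and-grow step for a sparse layer.

    Drops the lowest-magnitude active weights, then grows the same number
    of connections chosen from a random sample of inactive positions,
    scored by gradient magnitude. The sampling step is an assumption about
    what "guided stochastic exploration" looks like, not the paper's code.
    """
    rng = np.random.default_rng() if rng is None else rng
    active = np.flatnonzero(mask)
    inactive = np.flatnonzero(mask == 0)
    n_drop = max(1, int(drop_frac * active.size))

    # Prune: deactivate the smallest-magnitude active weights.
    drop = active[np.argsort(np.abs(weights.ravel()[active]))[:n_drop]]
    mask.ravel()[drop] = 0
    weights.ravel()[drop] = 0.0

    # Grow: sample a random subset of inactive positions (stochastic
    # exploration), keep those with the largest gradient magnitude (the
    # "guided" part), and activate them with zero-initialized weights.
    cand = rng.choice(inactive, size=min(n_candidates, inactive.size), replace=False)
    grow = cand[np.argsort(-np.abs(grads.ravel()[cand]))[:n_drop]]
    mask.ravel()[grow] = 1
    return weights, mask
```

Because only a fixed-size candidate sample of inactive connections is scored, the per-step cost stays proportional to the number of active weights rather than the dense layer size, consistent with the linear scaling in model width claimed above.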
📝 Abstract
The excessive computational requirements of modern artificial neural networks (ANNs) limit the machines that can run them. Sparsification of ANNs is often motivated by time, memory, and energy savings only during model inference, yielding no benefits during training. A growing body of work now focuses on providing the benefits of model sparsification during training as well. While these methods greatly improve training efficiency, the training algorithms yielding the most accurate models still materialize the dense weights, or compute dense gradients, during training. We propose an efficient, always-sparse training algorithm with excellent scaling to larger and sparser models, supported by its linear time complexity with respect to model width during both training and inference. Moreover, our guided stochastic exploration algorithm improves on the accuracy of previous sparse training methods. We evaluate our method on CIFAR-10/100 and ImageNet using ResNet, VGG, and ViT models, and compare it against a range of sparsification methods.