🤖 AI Summary
Standard N:M sparsity patterns are not preserved under matrix transposition, so they accelerate only the forward pass and leave training inefficient. Existing solvers for transposable N:M masks either fail to scale to large models or are restricted to M=4, which yields a suboptimal compression-accuracy trade-off. This work introduces an efficient solver for transposable N:M masks with arbitrary N:M values that scales to billion-parameter models. Mask generation is formulated as an optimal transport problem and solved with entropy regularization and Dykstra's algorithm, followed by a rounding step, in a tensor-based implementation that exploits GPU parallelism. The solver achieves up to 100x speedup with only 1-10% error compared to existing methods. On LLaMA3.2-8B, transposable 16:32 sparsity stays close to its standard N:M counterpart and outperforms the standard 2:4 sparse model.
📝 Abstract
Network pruning reduces the computational requirements of large neural networks, and N:M sparsity (retaining only N out of every M consecutive weights) offers a compelling balance between compressed model quality and hardware acceleration. However, standard N:M sparsity accelerates only forward-pass computations, because N:M patterns are not preserved under matrix transposition. This limits efficiency during training, where both the forward and backward passes are computationally intensive. Transposable N:M sparsity has been proposed to address this limitation, but existing methods for finding transposable N:M sparse masks either fail to scale to large models or are restricted to M=4, which results in a suboptimal compression-accuracy trade-off. We introduce an efficient solver for transposable N:M masks that scales to billion-parameter models. We formulate mask generation as a set of optimal transport problems and solve them through entropy regularization and Dykstra's algorithm, followed by a rounding procedure. Our tensor-based implementation exploits GPU parallelism, achieving up to 100x speedup with only 1-10% error compared to existing methods. Our approach can be integrated with layer-wise N:M pruning frameworks, including Wanda, SparseGPT, and ALPS, to produce transposable N:M sparse models with arbitrary N:M values. Experiments show that LLaMA3.2-8B with transposable 16:32 sparsity maintains performance close to its standard N:M counterpart and outperforms the standard 2:4 sparse model, showing the practical value of our approach.
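
To make the transposition issue concrete, the following is a minimal NumPy sketch, not the solver described in the abstract. It builds a standard magnitude-based N:M mask along rows and checks whether a mask is also N:M after transposition. The function names (`nm_mask`, `is_nm`, `is_transposable_nm`) and the toy 4x4 matrices are illustrative assumptions, and "N:M" is taken here to mean exactly N kept entries per group of M consecutive weights.

```python
import numpy as np

def nm_mask(weights, n, m):
    """Standard N:M mask: keep the n largest-magnitude entries in every
    group of m consecutive weights along each row (hypothetical helper)."""
    rows, cols = weights.shape
    assert cols % m == 0, "number of columns must be divisible by m"
    groups = np.abs(weights).reshape(rows, cols // m, m)
    # Indices of the n largest magnitudes inside each group.
    top = np.argsort(-groups, axis=-1)[..., :n]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, top, True, axis=-1)
    return mask.reshape(rows, cols)

def is_nm(mask, n, m):
    """True if every group of m consecutive entries along each row keeps exactly n."""
    rows, cols = mask.shape
    if cols % m != 0:
        return False
    kept = mask.reshape(rows, cols // m, m).sum(axis=-1)
    return bool((kept == n).all())

def is_transposable_nm(mask, n, m):
    """True if both the mask and its transpose satisfy the N:M pattern."""
    return is_nm(mask, n, m) and is_nm(mask.T, n, m)

# Toy weights: the two largest magnitudes of every row sit in columns 0 and 1.
W = np.array([[4.0, 3.0, 2.0, 1.0]] * 4)
row_mask = nm_mask(W, n=2, m=4)

# A hand-written mask with exactly 2 kept entries in every row and every column.
block_mask = np.array([[1, 1, 0, 0],
                       [1, 1, 0, 0],
                       [0, 0, 1, 1],
                       [0, 0, 1, 1]], dtype=bool)

print(is_nm(row_mask, 2, 4), is_transposable_nm(row_mask, 2, 4))
print(is_transposable_nm(block_mask, 2, 4))
```

In this toy case the row-wise mask is 2:4 along rows, but its transpose keeps four entries in two rows and none in the others, so it is not transposable. The hand-written block mask satisfies the pattern in both directions. Finding such a mask while retaining as much weight magnitude as possible is the problem the solver targets.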