🤖 AI Summary
Greedy heuristic-based pruning methods for large language models (LLMs) neglect inter-weight dependencies, leading to accumulated errors and suboptimal sparsity-accuracy trade-offs.
Method: This work formulates layer-wise pruning as a combinatorial optimization problem and relaxes the combinatorial constraints into a convex set, turning mask selection into a differentiable optimization task. The relaxed problem is solved with the Frank-Wolfe (FW) algorithm, whose convergence guarantees yield an approximation guarantee for the original combinatorial problem once the relaxed solution is rounded. No full retraining is needed; only a small calibration dataset is required for layer-wise pruning that accounts for weight interactions.
Contribution/Results: Evaluated on state-of-the-art GPT architectures, the method drastically reduces per-layer pruning error and mitigates accuracy degradation compared to strong baselines, achieving an average +2.1% improvement in task accuracy under identical sparsity constraints, while remaining memory-efficient and improving inference efficiency, thereby addressing a fundamental limitation of layer-wise greedy pruning.
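The combinatorial problem and its relaxation described above can be written schematically as follows (the notation is an assumption for illustration, not taken from the paper): for a layer with weight matrix W, calibration inputs X, and pruning mask M,

```latex
\min_{M}\ \big\| (M \odot W - W)\,X \big\|_F^2
\quad \text{s.t.}\quad M \in \{0,1\}^{d_{\mathrm{out}} \times d_{\mathrm{in}}},\ \
\textstyle\sum_{i,j} M_{ij} \le k,
```

and the convex relaxation replaces the binary constraint with $M \in [0,1]^{d_{\mathrm{out}} \times d_{\mathrm{in}}}$, yielding a differentiable objective over a polytope, the setting in which the Frank-Wolfe algorithm applies.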
📝 Abstract
Pruning is a common technique to reduce the compute and storage requirements of neural networks. While conventional approaches typically retrain the model to recover pruning-induced performance degradation, state-of-the-art Large Language Model (LLM) pruning methods operate layer-wise, minimizing the per-layer pruning error on a small calibration dataset to avoid full retraining, which is considered computationally prohibitive for LLMs. However, finding the optimal pruning mask is a hard combinatorial problem and solving it to optimality is intractable. Existing methods hence rely on greedy heuristics that ignore the weight interactions in the pruning objective. In this work, we instead consider the convex relaxation of these combinatorial constraints and solve the resulting problem using the Frank-Wolfe (FW) algorithm. Our method drastically reduces the per-layer pruning error, outperforms strong baselines on state-of-the-art GPT architectures, and remains memory-efficient. We provide theoretical justification by showing that, combined with the convergence guarantees of the FW algorithm, we obtain an approximate solution to the original combinatorial problem upon rounding the relaxed solution to integrality.
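The abstract's recipe (relax the binary mask to a polytope, run Frank-Wolfe, round back to a binary mask) can be illustrated with a minimal NumPy sketch. This is not the paper's implementation; the per-row sparsity budget, step-size schedule, and rounding rule are illustrative assumptions.

```python
import numpy as np

def frank_wolfe_prune(W, X, k, iters=100):
    """Sketch: approximately solve the relaxed layer-wise pruning problem
        min_M ||(M * W - W) X||_F^2
        s.t.  M in [0,1]^{out x in}, at most k nonzeros per row,
    with the Frank-Wolfe algorithm, then round M to a binary mask.
    (Per-row budget and rounding rule are illustrative assumptions.)"""
    H = X @ X.T                    # (in, in) second moment of calibration inputs
    M = np.zeros_like(W)           # start at a vertex of the feasible polytope
    rows = np.arange(W.shape[0])[:, None]
    for t in range(iters):
        # Gradient of f(M) = ||(M*W - W) X||_F^2 with respect to M.
        G = 2.0 * (((M * W - W) @ H) * W)
        # Linear minimization oracle over the polytope: per row, set to 1
        # the (at most) k entries with the most negative gradient.
        S = np.zeros_like(M)
        idx = np.argsort(G, axis=1)[:, :k]        # k smallest gradients per row
        S[rows, idx] = (G[rows, idx] < 0).astype(W.dtype)
        gamma = 2.0 / (t + 2.0)                   # standard FW step size
        M = (1 - gamma) * M + gamma * S
    # Round the fractional solution: keep the k largest mask entries per row.
    mask = np.zeros_like(M)
    keep = np.argsort(-M, axis=1)[:, :k]
    mask[rows, keep] = 1.0
    return mask
```

A caller would then prune the layer as `W_pruned = mask * W`; the rounding step is what turns the relaxed solution into the approximate solution to the combinatorial problem mentioned in the abstract.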