Beyond Surrogate Gradients: Fully Differentiable Token Pruning for Vision-Language Models

📅 2026-05-27
📈 Citations: 0
Influential: 0
🤖 AI Summary
Existing visual token pruning methods for vision-language models (VLMs) rely on Gumbel-Softmax to approximate discrete token selection during training, so optimization follows surrogate gradients and yields unreliable importance estimates. This work proposes DiffPrune, which reframes pruning as continuous, differentiable control of token information. An Information Throttler modulates each token representation with variance-preserving noise conditioned on a learned importance score, with higher scores suppressing less information, which makes training differentiable end to end. At inference, tokens are removed by hard thresholding on the learned scores, with no discrete sampling or surrogate gradients. Across ten VLM benchmarks, DiffPrune retains 96.5% of the full model’s accuracy while accelerating LLM prefill by 2.85×, with only 0.69 ms of inference overhead.
📝 Abstract
Visual token pruning reduces the computational cost of Vision-Language Models (VLMs) by removing redundant visual tokens. Existing methods typically rely on Gumbel-Softmax to approximate discrete selection during training. However, the optimization is driven by surrogate gradients rather than the true selection process, leading to unreliable learning of token importance. In this paper, we propose DiffPrune, which reformulates pruning as continuous control of token information instead of discrete selection learning. Specifically, we introduce an Information Throttler that modulates each token using variance-preserving noise conditioned on importance scores, where higher scores induce less information suppression during training. This design directly operates on token representations, naturally providing a fully differentiable optimization path for learning token importance. At inference, tokens are removed via hard thresholding on the learned scores. Across ten VLM benchmarks, DiffPrune retains 96.5% of full-model accuracy while accelerating LLM prefill by 2.85×, with only 0.69 ms of inference overhead.
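The abstract does not give the Information Throttler's formula, so the following is only a minimal sketch of one way score-conditioned, variance-preserving noise could look. The mixing rule sqrt(s) * x + sqrt(1 - s) * sigma * eps, the function names throttle and hard_prune, and the 0.5 threshold are all assumptions made for illustration, not the authors' implementation. The point is that the score enters through a smooth expression (for scores strictly between 0 and 1), so in an autodiff framework gradients could reach it without a surrogate.

```python
import math
import random


def throttle(token, score, rng):
    """Blend a token with Gaussian noise; a higher score keeps more signal.

    Assumed form (not taken from the paper):
        out = sqrt(s) * x + sqrt(1 - s) * sigma * eps,  eps ~ N(0, 1)
    where sigma is the token's own standard deviation, so the expected
    variance of the output matches the variance of the input.
    """
    if not 0.0 <= score <= 1.0:
        raise ValueError("score must lie in [0, 1]")
    mean = sum(token) / len(token)
    sigma = math.sqrt(sum((v - mean) ** 2 for v in token) / len(token))
    keep = math.sqrt(score)          # weight on the original features
    noise = math.sqrt(1.0 - score)   # weight on the injected noise
    return [keep * v + noise * sigma * rng.gauss(0.0, 1.0) for v in token]


def hard_prune(tokens, scores, threshold=0.5):
    """Inference-time selection: keep tokens whose score exceeds the threshold."""
    return [t for t, s in zip(tokens, scores) if s > threshold]


if __name__ == "__main__":
    rng = random.Random(0)
    tokens = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(4)]
    scores = [0.9, 0.1, 0.7, 0.3]  # hypothetical learned importance scores
    noisy = [throttle(t, s, rng) for t, s in zip(tokens, scores)]  # training path
    kept = hard_prune(tokens, scores)                              # inference path
    print(len(noisy), len(kept))
```

In this sketch a score of 1 passes the token through unchanged and a score of 0 replaces it with noise of matching scale, which mirrors the abstract's statement that higher scores induce less information suppression. At inference no noise is drawn at all; only the threshold comparison runs.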
Problem

Research questions and friction points this paper is trying to address.

token pruning
vision-language models
surrogate gradients
discrete selection
computational efficiency
Innovation

Methods, ideas, or system contributions that make the work stand out.

fully differentiable pruning
token importance learning
information throttler
vision-language models
continuous token control