Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs

📅 2025-03-21
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
During LLM pretraining, knowledge distillation with sparse logits caching introduces bias in teacher probability estimation, degrading student model calibration and performance. To address this, we propose an importance-sampling-based stochastic knowledge distillation method. Our approach is the first to achieve unbiased probability estimation and expected gradient consistency under sparse caching constraints, balancing theoretical rigor with engineering efficiency. By replacing deterministic Top-K truncation with importance-weighted random sampling, we reduce logits storage density by over 90%, enabling near-full-logit performance across 300M–3B parameter models. The method yields substantial training speedup with negligible overhead (<10% additional cost), improves calibration error by 12.7%, and enhances downstream task generalization by 2.3% on average.

Technology Category

Application Category

📝 Abstract
Knowledge distillation can be a cost-effective technique to distill knowledge in Large Language Models, if the teacher output logits can be pre-computed and cached. However, successfully applying this to pre-training remains largely unexplored. In this work, we prove that naive approaches for sparse knowledge distillation such as caching Top-K probabilities, while intuitive, provide biased estimates of teacher probability distribution to the student, resulting in suboptimal performance and calibration. We propose an importance-sampling-based method `Random Sampling Knowledge Distillation', which provides unbiased estimates, preserves the gradient in expectation, and requires storing significantly sparser logits. Our method enables faster training of student models with marginal overhead (<10%) compared to cross-entropy based training, while maintaining competitive performance compared to full distillation, across a range of model sizes from 300M to 3B.
Problem

Research questions and friction points this paper is trying to address.

Biased estimates in sparse knowledge distillation methods
Suboptimal performance and calibration in student models
High storage cost for teacher logits in distillation
Innovation

Methods, ideas, or system contributions that make the work stand out.

Importance-sampling-based unbiased logit estimation
Sparse logit storage for efficient training
Marginal overhead with competitive performance
🔎 Similar Papers