🤖 AI Summary
Existing knowledge distillation methods predominantly employ static logit-matching strategies (uniform token weighting and fixed temperature scaling) that fail to track the dynamic learning process of student models. This work proposes a token-adaptive knowledge distillation framework tailored to large language models (LLMs). Our approach addresses this mismatch by (1) introducing a loss-driven adaptive focusing module, grounded in a unified token difficulty metric, to enable fine-grained gradient reweighting, and (2) proposing an inverse difficulty-based temperature scaling strategy that assigns lower temperatures to harder tokens, thereby strengthening their supervision signals. The framework is plug-and-play compatible with standard logit distillation pipelines and supports diverse LLM architectures and downstream tasks. Extensive experiments across multiple LLM backbones and benchmarks demonstrate consistent and significant improvements over static baselines, validating the framework's effectiveness and cross-task generalization.
📝 Abstract
Knowledge distillation (KD) is a key technique for compressing large language models (LLMs), yet prevailing logit-based methods employ static strategies that are misaligned with the dynamic learning process of student models. These methods treat all tokens indiscriminately and apply a single fixed temperature, resulting in suboptimal knowledge transfer. To address these limitations, we propose LLM-Oriented Token-Adaptive Knowledge Distillation (AdaKD), a novel framework that adapts the distillation process to the real-time learning state of each token. AdaKD consists of two synergistic modules driven by a unified token difficulty metric. First, our Loss-Driven Adaptive Token Focusing (LATF) module dynamically adjusts the distillation focus by monitoring the student's learning stability, concentrating computational resources on the most valuable tokens at each training phase. Second, we introduce Inverse Difficulty Temperature Scaling (IDTS), a counterintuitive yet effective token-level temperature strategy: it applies low temperatures to difficult tokens for targeted error correction, and high temperatures to easy tokens so that the student learns from the teacher's complete, smooth output distribution, thereby improving generalization. As a plug-and-play framework, AdaKD consistently improves the performance of various distillation methods across multiple model architectures and benchmarks.
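To make the two ideas concrete, the sketch below illustrates one plausible reading of the abstract: per-token difficulty measured by the student's cross-entropy, difficulty-weighted distillation loss (the "focusing" idea), and temperatures that decrease as difficulty increases (the IDTS idea). This is a minimal NumPy illustration, not the authors' implementation; all function names, the normalization scheme, and the temperature range `[1.0, 4.0]` are assumptions made for the example.

```python
import numpy as np

def softmax(logits, temperature):
    """Temperature-scaled softmax over the last axis."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def token_difficulty(student_logits, target_ids):
    """Per-token student cross-entropy as a difficulty proxy (assumed metric)."""
    probs = softmax(student_logits, 1.0)
    return -np.log(probs[np.arange(len(target_ids)), target_ids] + 1e-12)

def inverse_difficulty_temperature(difficulty, t_min=1.0, t_max=4.0):
    """Harder tokens (higher loss) get lower temperatures, i.e. sharper,
    more corrective teacher targets; easy tokens get smoother ones."""
    span = difficulty.max() - difficulty.min() + 1e-12
    d = (difficulty - difficulty.min()) / span  # normalize to [0, 1]
    return t_max - d * (t_max - t_min)

def token_adaptive_kd_loss(student_logits, teacher_logits, target_ids):
    """Difficulty-weighted sum of per-token KL(teacher || student),
    each token distilled at its own temperature."""
    diff = token_difficulty(student_logits, target_ids)
    temps = inverse_difficulty_temperature(diff)
    kls = []
    for s, t, tau in zip(student_logits, teacher_logits, temps):
        p = softmax(t, tau)
        q = softmax(s, tau)
        kls.append(float(np.sum(p * (np.log(p + 1e-12) - np.log(q + 1e-12)))))
    weights = diff / diff.sum()  # focus gradient budget on harder tokens
    return float(np.sum(weights * np.array(kls)))
```

In a real training loop these pieces would operate on autograd tensors and be combined with the task loss; the sketch only shows how a single difficulty signal can drive both the token weighting and the per-token temperature.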