🤖 AI Summary
To address gradient dilution and inference latency in small language models (SLMs) during chain-of-thought (CoT) distillation, both caused by long reasoning chains, this paper introduces a chunk-wise training and skip-thinking inference paradigm. It proposes what it describes as the first semantic-aware chunk-wise training mechanism. This mechanism uses a heuristic search to partition reasoning chains into semantically coherent chunks, so that gradient updates concentrate on critical reasoning tokens. It also designs a skip-thinking mechanism that lets SLMs dynamically bypass redundant intermediate chunks, balancing accuracy and efficiency. Combined with chunk-level supervised distillation and multi-task joint optimization, the approach improves reasoning performance on models including Phi-3 and Qwen2-1.5B: accuracy increases by up to 8.2% over baselines, while average inference latency decreases by 37%.
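A small calculation makes the gradient-dilution argument concrete. Under a token-averaged cross-entropy loss, L = (1/N) Σ ℓ_i, every supervised token receives weight 1/N, so k core reasoning tokens jointly carry k/N of the total weight. The sketch below uses a hypothetical five-chunk rationale. All token counts are invented for illustration and are not taken from the paper. It compares that share when the whole rationale is the training target with the share when a single reasoning chunk is the target. Supervising one chunk per update is an assumption about how chunk-level training is set up.

```python
from dataclasses import dataclass

@dataclass
class Chunk:
    kind: str        # "reasoning" or "non-reasoning" (hypothetical labels)
    num_tokens: int  # supervised tokens in the chunk
    num_core: int    # core reasoning tokens in the chunk

def core_share(chunks: list[Chunk]) -> float:
    """Fraction of a token-averaged loss carried by core reasoning tokens.

    With L = (1/N) * sum(l_i), each supervised token has weight 1/N,
    so k core tokens carry k/N of the total weight."""
    total = sum(c.num_tokens for c in chunks)
    core = sum(c.num_core for c in chunks)
    return core / total

# HYPOTHETICAL rationale; all counts are illustrative, not from the paper.
rationale = [
    Chunk("non-reasoning", 30, 0),  # restates the question
    Chunk("reasoning", 40, 4),
    Chunk("non-reasoning", 20, 0),  # transition
    Chunk("reasoning", 40, 4),
    Chunk("non-reasoning", 30, 0),  # summary
]

whole = core_share(rationale)  # whole rationale supervised in one update
# ASSUMPTION: chunk-wise training supervises one chunk per update.
per_chunk = [core_share([c]) for c in rationale if c.kind == "reasoning"]
print(f"whole rationale: {whole:.1%}")                          # whole rationale: 5.0%
print("per reasoning chunk:", [f"{s:.1%}" for s in per_chunk])  # per reasoning chunk: ['10.0%', '10.0%']
```

In this toy setting, excluding the non-reasoning tokens from the update doubles the weight on core tokens. The calculation only illustrates the direction of the effect and does not reproduce any result from the paper.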
📝 Abstract
Chain-of-thought (CoT) distillation allows a large language model (LLM) to guide a small language model (SLM) in reasoning tasks. Existing methods train the SLM to learn the entire long rationale in a single iteration, which causes two issues. 1) Long rationales lead to a large token-level batch size during training. The gradients of core reasoning tokens (i.e., tokens that directly affect the correctness of subsequent reasoning) are therefore over-smoothed, because these tokens make up only a tiny fraction of the rationale. As a result, the SLM converges to sharp minima where it fails to grasp the reasoning logic. 2) Responses are slow, because the SLM must generate a long rationale before reaching the answer. We therefore propose chunk-wise training (CWT), which uses a heuristic search to divide the rationale into internally semantically coherent chunks and focuses the SLM on learning from only one chunk per iteration. In this way, CWT naturally isolates non-reasoning chunks that contain no core reasoning tokens (e.g., summary and transitional chunks) from the SLM's learning of reasoning chunks. This raises the fraction of core reasoning tokens in the corresponding iteration. Building on CWT, we propose skip-thinking training (STT), which makes the SLM automatically skip non-reasoning intermediate chunks to reach the answer, improving reasoning speed while maintaining accuracy. We validate our approach on a variety of SLMs and multiple reasoning tasks.
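The pipeline described above (chunk the rationale, train on one chunk at a time, then shorten targets for skip-thinking) can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's implementation. The abstract does not specify its heuristic search, so `chunk_rationale` substitutes a greedy merge of adjacent sentences by word-overlap (Jaccard) similarity. `is_reasoning` is a hypothetical toy rule that treats a chunk as reasoning if it contains an arithmetic step. `cwt_examples` assumes that each CWT example conditions on the question plus all earlier chunks and supervises only the current chunk. `skip_thinking_target` assumes that an STT target simply drops non-reasoning intermediate chunks. All function names are hypothetical.

```python
import re

def split_sentences(text: str) -> list[str]:
    # Naive splitter: break after '.', '!' or '?' followed by whitespace.
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

def jaccard(a: str, b: str) -> float:
    # Word-set overlap, used here as a cheap proxy for semantic coherence.
    wa = set(re.findall(r"\w+", a.lower()))
    wb = set(re.findall(r"\w+", b.lower()))
    return len(wa & wb) / len(wa | wb) if wa | wb else 0.0

def chunk_rationale(rationale: str, threshold: float = 0.2) -> list[str]:
    # ASSUMPTION: greedy stand-in for the paper's (unspecified) heuristic search.
    # A sentence joins the current chunk if it overlaps it enough, else starts a new one.
    chunks: list[str] = []
    for sent in split_sentences(rationale):
        if chunks and jaccard(chunks[-1], sent) >= threshold:
            chunks[-1] += " " + sent
        else:
            chunks.append(sent)
    return chunks

def is_reasoning(chunk: str) -> bool:
    # HYPOTHETICAL toy rule: a chunk "reasons" if it contains an arithmetic step like "3 + 5".
    return re.search(r"\d\s*[-+*/=]\s*\d", chunk) is not None

def cwt_examples(question: str, chunks: list[str]) -> list[tuple[str, str]]:
    # ASSUMPTION: one (prompt, target) pair per chunk; the prompt holds the question
    # and all earlier chunks, and only the current chunk is supervised.
    return [(" ".join([question, *chunks[:i]]), chunk) for i, chunk in enumerate(chunks)]

def skip_thinking_target(chunks: list[str]) -> str:
    # ASSUMPTION: keep the first and last chunks, drop non-reasoning chunks in between.
    if len(chunks) <= 2:
        return " ".join(chunks)
    middle = [c for c in chunks[1:-1] if is_reasoning(c)]
    return " ".join([chunks[0], *middle, chunks[-1]])

question = "Tom has 3 apples and buys 5 more. How many apples does he have?"
rationale = ("Tom has 3 apples. Tom buys 5 apples. Now we add the two numbers. "
             "3 + 5 = 8. So the answer is 8.")

chunks = chunk_rationale(rationale)
examples = cwt_examples(question, chunks)
print(chunks)
# ['Tom has 3 apples. Tom buys 5 apples.', 'Now we add the two numbers.', '3 + 5 = 8.', 'So the answer is 8.']
print(skip_thinking_target(chunks))
# Tom has 3 apples. Tom buys 5 apples. 3 + 5 = 8. So the answer is 8.
```

In a training loop, one pair from `cwt_examples` would be drawn per iteration, with the loss computed only on the target tokens. The greedy merge makes a single left-to-right pass over the sentences, which keeps it cheap. However, it never revisits an earlier split, so a real search over chunk boundaries could produce different chunks.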