🤖 AI Summary
Existing chain-of-thought (CoT) distillation methods transfer teacher-generated rationales but not the teacher model's dynamic attention to critical information during reasoning, which limits the student model's capacity for complex inference. This work proposes a distillation framework that, to the authors' knowledge, is the first to transfer the teacher's stepwise attention on key information to the student model. A Mixture of Layers module additionally provides dynamic alignment between teacher and student layers, guiding the student to concentrate progressively on salient reasoning cues. The approach yields consistent performance improvements for compact models across multiple mathematical and commonsense reasoning datasets.
📝 Abstract
The significant computational demands of large language models have increased interest in distilling reasoning abilities into smaller models via Chain-of-Thought (CoT) distillation. Current CoT distillation methods mainly focus on transferring teacher-generated rationales for complex reasoning to student models. However, they do not adequately explore the teacher's dynamic attention toward critical information during reasoning. We find that language models progressively shift their attention toward key information during reasoning, and that these shifts carry essential clues for drawing conclusions. Building on this observation, we introduce a novel CoT distillation framework that transfers the teacher's stepwise attention on key information to the student model. This establishes structured guidance for the student's progressive concentration on key information during reasoning. More importantly, we develop a Mixture of Layers module that dynamically aligns teacher and student layers, adapting to the differences between their layer structures. Our method achieves consistent performance improvements across multiple mathematical and commonsense reasoning datasets. To our knowledge, it is the first method to leverage stepwise attention within CoT distillation to improve small model reasoning.
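The abstract does not give the loss function or the internals of the Mixture of Layers module, so the following is only an illustrative sketch of how stepwise attention transfer with a learned layer mixture could be expressed. Everything in it is an assumption rather than the authors' implementation: the function names (`mix_teacher_layers`, `stepwise_attention_loss`), the use of KL divergence as the matching objective, the softmax gate over teacher layers, and the nested-list layout `attn[step][layer][token]`, where each innermost list is an attention distribution over the same set of key tokens.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def mix_teacher_layers(teacher_layers, gate_logits):
    """Blend teacher layers into one target distribution (hypothetical gate).

    teacher_layers: one attention distribution per teacher layer.
    gate_logits: one unnormalised score per teacher layer; in a real
    system these would be learned parameters.
    """
    weights = softmax(gate_logits)
    n_tokens = len(teacher_layers[0])
    return [
        sum(w * layer[t] for w, layer in zip(weights, teacher_layers))
        for t in range(n_tokens)
    ]


def stepwise_attention_loss(teacher_attn, student_attn, gate_logits):
    """Average KL between mixed teacher attention and student attention.

    teacher_attn[step][teacher_layer][token]
    student_attn[step][student_layer][token]
    gate_logits[student_layer][teacher_layer]
    The KL objective is an assumption; the source does not name one.
    """
    total, count = 0.0, 0
    for teacher_step, student_step in zip(teacher_attn, student_attn):
        for student_layer, logits in zip(student_step, gate_logits):
            target = mix_teacher_layers(teacher_step, logits)
            total += kl_divergence(target, student_layer)
            count += 1
    return total / count


if __name__ == "__main__":
    # Two reasoning steps, two teacher layers, one student layer, three tokens.
    teacher = [
        [[0.6, 0.3, 0.1], [0.5, 0.4, 0.1]],
        [[0.2, 0.2, 0.6], [0.1, 0.2, 0.7]],
    ]
    student = [[[0.4, 0.4, 0.2]], [[0.3, 0.3, 0.4]]]
    gates = [[0.0, 0.0]]  # equal weight on both teacher layers
    print(stepwise_attention_loss(teacher, student, gates))
```

Under these assumptions, the per-step loop is what makes the guidance "stepwise": the student is matched to the teacher's attention at each reasoning step rather than to a single averaged map, and the gate lets each student layer draw on a weighted blend of teacher layers instead of a fixed one-to-one layer mapping.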