Dual-Head Reasoning Distillation: Improving Classifier Accuracy with Train-Time-Only Reasoning

📅 2025-09-25
📈 Citations: 0
✨ Influential: 0
📄 PDF
🤖 AI Summary
To address the trade-off where chain-of-thought (CoT) prompting improves classification accuracy but severely degrades inference throughput, this paper proposes dual-head reasoning distillation. During training, a dedicated reasoning head is introduced and supervised by teacher-generated reasoning chains, jointly optimizing label cross-entropy and a token-level language-modeling loss over input-plus-rationale sequences. At inference time, this head is disabled, retaining only the lightweight classification head and incurring zero additional computational overhead. The core innovation lies in decoupling the training and inference architectures, enabling the model to internalize reasoning knowledge implicitly without generating verbose rationales. Evaluated on seven SuperGLUE tasks, the method achieves relative accuracy gains of 0.65–5.47% over pooled baselines, with particularly pronounced improvements on causal and entailment tasks. Crucially, inference throughput reaches 96–142× that of standard CoT prompting, effectively reconciling high accuracy with high efficiency.

πŸ“ Abstract
Chain-of-Thought (CoT) prompting often improves classification accuracy, but it introduces a significant throughput penalty with rationale generation (Wei et al., 2022; Cheng and Van Durme, 2024). To resolve this trade-off, we introduce Dual-Head Reasoning Distillation (DHRD), a simple training method for decoder-only language models (LMs) that adds (i) a pooled classification head used during training and inference and (ii) a reasoning head supervised by teacher rationales used only in training. We train with a loss function that is a weighted sum of label cross-entropy and token-level LM loss over input-plus-rationale sequences. On seven SuperGLUE tasks, DHRD yields relative gains of 0.65-5.47% over pooled baselines, with notably larger gains on entailment/causal tasks. Since we disable the reasoning head at test time, inference throughput matches pooled classifiers and exceeds CoT decoding on the same backbones by 96-142 times in QPS.
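The training objective described in the abstract, a weighted sum of label cross-entropy from the pooled classification head and token-level LM loss from the reasoning head, can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the weight name `lambda_lm`, the toy shapes, and the random inputs are all hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_entropy(logits, targets):
    # mean negative log-likelihood over integer targets
    probs = softmax(logits)
    return -np.mean(np.log(probs[np.arange(len(targets)), targets]))

def dhrd_loss(cls_logits, labels, lm_logits, rationale_tokens, lambda_lm=0.5):
    """DHRD-style training loss (sketch): label cross-entropy from the
    classification head plus a weighted token-level LM loss from the
    reasoning head over the teacher rationale. lambda_lm is a
    hypothetical weighting hyperparameter, not from the paper."""
    l_cls = cross_entropy(cls_logits, labels)
    vocab = lm_logits.shape[-1]
    # flatten (batch, seq, vocab) token predictions into one big batch
    l_lm = cross_entropy(lm_logits.reshape(-1, vocab),
                         rationale_tokens.reshape(-1))
    return l_cls + lambda_lm * l_lm

# toy example: 2 examples, 3 classes, 4-token rationales, vocab of 10
rng = np.random.default_rng(0)
cls_logits = rng.normal(size=(2, 3))
labels = np.array([0, 2])
lm_logits = rng.normal(size=(2, 4, 10))
rationale = rng.integers(0, 10, size=(2, 4))

train_loss = dhrd_loss(cls_logits, labels, lm_logits, rationale)

# at inference the reasoning head is dropped entirely: only the
# classification head's logits are computed, so throughput matches
# a plain pooled classifier
pred = cls_logits.argmax(axis=-1)
```

The key design point mirrored here is that the LM term exists only in the loss: at test time nothing from the reasoning head is evaluated, which is why the method pays no generation cost for its rationale supervision.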
Problem

Research questions and friction points this paper is trying to address.

Improving classifier accuracy without inference-time reasoning overhead
Eliminating throughput penalty from Chain-of-Thought rationale generation
Maintaining training benefits while matching baseline inference speed
Innovation

Methods, ideas, or system contributions that make the work stand out.

Dual-head distillation with train-time reasoning head
Weighted loss combining classification and rationale supervision
Test-time inference matches pooled classifier throughput