🤖 AI Summary
Transformer language models lack a direct inductive bias for the hierarchical, tree-like structure of language, which may limit their robustness and data efficiency. To address this, we propose TreeReg, an auxiliary loss that converts bracketing decisions from silver parses into differentiable orthogonality constraints over hidden states, softly injecting syntactic priors without altering the model architecture or increasing inference overhead. TreeReg is added to the standard LM objective and applies to pre-training, continued pre-training, and fine-tuning. LMs pre-trained with TreeReg on corpora such as WikiText-103 reach up to 10% lower perplexity on out-of-distribution data and up to 9.5 points better syntactic generalization, and need less than half the training data to outperform standard LMs. Continued pre-training of Sheared Llama with TreeReg improves syntactic generalization, and fine-tuning on MultiNLI with TreeReg mitigates performance degradation on adversarial NLI benchmarks by 41.2 points. These results suggest that structured syntactic regularization, without architectural modification, can improve both data efficiency and robustness.
📝 Abstract
While compositional accounts of human language understanding are based on a hierarchical tree-like process, neural models like transformers lack a direct inductive bias for such tree structures. Introducing syntactic inductive biases could unlock more robust and data-efficient learning in transformer language models (LMs), but existing methods for incorporating such structure greatly restrict models, either limiting their expressivity or increasing inference complexity. This work instead aims to softly inject syntactic inductive biases into given transformer circuits, through a structured regularizer. We introduce TreeReg, an auxiliary loss function that converts bracketing decisions from silver parses into a set of differentiable orthogonality constraints on vector hidden states. TreeReg integrates seamlessly with the standard LM objective, requiring no architectural changes. LMs pre-trained with TreeReg on natural language corpora such as WikiText-103 achieve up to 10% lower perplexities on out-of-distribution data and up to 9.5 point improvements in syntactic generalization, requiring less than half the training data to outperform standard LMs. TreeReg still provides gains for pre-trained LLMs: Continued pre-training of Sheared Llama with TreeReg results in improved syntactic generalization, and fine-tuning on MultiNLI with TreeReg mitigates degradation of performance on adversarial NLI benchmarks by 41.2 points. We release all code to guide future research.
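The idea of turning bracketing decisions into orthogonality constraints on hidden states can be illustrated with a small sketch. This is not the TreeReg loss as defined in the paper: the function names (`mean_vector`, `cosine`, `orthogonality_penalty`), the use of mean pooling, and the squared-cosine penalty are all assumptions chosen for illustration. The sketch only shows how a set of bracketed spans could yield a scalar that shrinks as each span's representation becomes orthogonal to its surrounding context.

```python
import math


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def cosine(u, v, eps=1e-8):
    """Cosine similarity, with eps guarding against zero-norm vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)


def orthogonality_penalty(hidden, spans):
    """Hypothetical penalty (an assumption, not the paper's definition).

    hidden: list of T hidden-state vectors, one per token.
    spans:  list of (start, end) half-open token ranges taken from a parse.

    For each span, compare the mean state inside the span with the mean
    state outside it, and average the squared cosine similarities.
    Spans with no inside or no outside tokens are skipped.
    """
    total, count = 0.0, 0
    for start, end in spans:
        inside = hidden[start:end]
        outside = hidden[:start] + hidden[end:]
        if not inside or not outside:
            continue
        similarity = cosine(mean_vector(inside), mean_vector(outside))
        total += similarity * similarity
        count += 1
    return total / count if count else 0.0


# Toy example: tokens 0-1 form one bracket, token 2 is its context.
hidden_states = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
brackets = [(0, 2)]
penalty = orthogonality_penalty(hidden_states, brackets)
```

In the toy example the span and its context are already orthogonal, so the penalty is zero. In an actual training setup, a term like this would be computed with an autodiff framework and added to the LM loss with some weighting coefficient; both the weighting scheme and the choice of layer are left open here.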