🤖 AI Summary
This work addresses a problem in large language model pretraining: because token embeddings are context-dependent, they exhibit high intra-class variance and high inter-class similarity, which limits the efficiency of representation learning. The study brings embedding similarity regularization, which has so far shown benefits mainly in supervised fine-tuning and classification, to large-scale language model pretraining. A contrastive loss pulls together embeddings that share the same label and pushes apart embeddings with different labels, which widens the decision margin in multi-class settings. The proposed approach consistently improves both training efficiency and representation quality across dense and mixture-of-experts (MoE) architectures, accelerating convergence by over 30% and boosting average zero-shot performance by more than 1% on standard benchmarks.
📝 Abstract
Pretraining large language models (LLMs) with next-token prediction has led to remarkable advances, yet the context-dependent nature of token embeddings in such models results in high intra-class variance and high inter-class similarity, which hinders the efficiency of representation learning. While similarity-based regularization has demonstrated benefits in supervised fine-tuning and classification tasks, its application and efficacy in large-scale LLM pretraining remain underexplored. In this work, we propose SimReg, an embedding similarity regularization loss that explicitly encourages token representations with the same ground-truth label within each sequence to be more similar, while enforcing separation from different-label tokens via a contrastive loss. Our analysis reveals that this mechanism yields gains by enlarging multi-class classification margins, thereby enabling more efficient classification. Extensive experiments across dense and Mixture-of-Experts (MoE) architectures demonstrate that SimReg consistently accelerates training convergence by over 30% and improves average zero-shot downstream performance by over 1% across standard benchmarks. Further ablation studies and analyses offer practical insights into hyperparameter tuning and loss effectiveness.
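To make the idea of a within-sequence similarity regularizer concrete, the sketch below shows one plausible form. The abstract does not give SimReg's exact formulation, so everything here is an assumption for illustration: the supervised-contrastive-style objective, the use of cosine similarity, the `temperature` parameter, and the function names `cosine` and `similarity_reg` are all hypothetical. Positions in a sequence whose ground-truth next token is the same are treated as positives for each other, and all other positions act as negatives.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + 1e-12)

def similarity_reg(hidden, labels, temperature=0.1):
    """Illustrative contrastive regularizer over one sequence (an assumed form).

    hidden: one embedding (list of floats) per position in the sequence.
    labels: the ground-truth next-token id for each position.
    Returns the mean loss over positions that have at least one other
    position with the same label; returns 0.0 if no such position exists.
    """
    n = len(hidden)
    total, anchors = 0.0, 0
    for i in range(n):
        positives = [j for j in range(n) if j != i and labels[j] == labels[i]]
        if not positives:
            continue  # no same-label partner in this sequence
        # Scaled similarities from anchor i to every other position.
        logits = {j: cosine(hidden[i], hidden[j]) / temperature
                  for j in range(n) if j != i}
        # Numerically stable log-sum-exp for the softmax denominator.
        peak = max(logits.values())
        log_denom = peak + math.log(
            sum(math.exp(v - peak) for v in logits.values()))
        # Average negative log-probability assigned to the positives.
        total += -sum(logits[j] - log_denom for j in positives) / len(positives)
        anchors += 1
    return total / anchors if anchors else 0.0

# Toy sequence: positions 0 and 2 share label 5, positions 1 and 3 share label 9.
labels = [5, 9, 5, 9]
clustered = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.05], [0.05, 1.0]]  # same label, similar vectors
mixed = [[1.0, 0.0], [1.0, 0.05], [0.0, 1.0], [0.05, 1.0]]      # same label, dissimilar vectors
loss_clustered = similarity_reg(clustered, labels)
loss_mixed = similarity_reg(mixed, labels)
```

In this toy case the regularizer is much smaller when same-label embeddings already point in similar directions than when they are interleaved with different-label ones. In training, a term like this would presumably be added to the next-token cross-entropy with a weighting coefficient; the weighting and the choice of which layer's representations to regularize are hyperparameters that the abstract does not specify.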