🤖 AI Summary
This work addresses the computational and memory bottlenecks of large language models in ultra-long-context scenarios, where standard Transformers incur prohibitive costs and existing sparse or linear attention mechanisms struggle to balance efficiency with performance. The authors propose MiniCPM-SALA, a 9B-parameter hybrid attention architecture that combines sparse attention (InfLLM-V2) and linear attention (Lightning Attention) in a 1:3 ratio chosen by a layer selection algorithm, together with Hybrid Positional Encoding (HyPE). A low-cost continual training framework converts pre-trained Transformer models into this hybrid form and reduces training costs by approximately 75% compared with training from scratch. On a single NVIDIA A6000D GPU, MiniCPM-SALA achieves up to 3.5× faster inference than the full-attention model at a 256K context length, supports sequences of up to 1M tokens, and keeps general capabilities comparable to full-attention models while substantially easing memory constraints.
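The 1:3 ratio can be pictured as a per-layer schedule of attention types. This is a minimal sketch, assuming the ratio means one sparse (InfLLM-V2) layer for every three linear (Lightning Attention) layers, and using a hypothetical even interleaving over a hypothetical 32-layer stack. The function name `hybrid_layer_schedule` and the `sparse_every` parameter are illustrative only; the paper's layer selection algorithm may place the sparse layers differently.

```python
from typing import List


def hybrid_layer_schedule(num_layers: int, sparse_every: int = 4) -> List[str]:
    """Assign an attention type to each layer under a 1:3 sparse:linear ratio.

    Hypothetical even interleaving: the last layer of every group of
    `sparse_every` layers is sparse, the others are linear. This only
    illustrates the ratio, not the paper's layer selection algorithm.
    """
    if num_layers <= 0 or sparse_every <= 0:
        raise ValueError("num_layers and sparse_every must be positive")
    schedule = []
    for i in range(num_layers):
        if i % sparse_every == sparse_every - 1:
            schedule.append("sparse")   # e.g. an InfLLM-V2 layer
        else:
            schedule.append("linear")   # e.g. a Lightning Attention layer
    return schedule


layers = hybrid_layer_schedule(32)  # 32 layers is an assumed depth
print(layers[:4])                                     # ['linear', 'linear', 'linear', 'sparse']
print(layers.count("sparse"), layers.count("linear"))  # 8 24
```

With 32 layers this yields 8 sparse and 24 linear layers, so only a quarter of the stack uses an attention type whose cost still depends on a stored key/value history.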
📝 Abstract
The evolution of large language models (LLMs) towards applications with ultra-long contexts faces challenges posed by the high computational and memory costs of the Transformer architecture. While existing sparse and linear attention mechanisms attempt to mitigate these issues, they typically trade memory efficiency against model performance. This paper introduces MiniCPM-SALA, a 9B-parameter hybrid architecture that combines the high-fidelity long-context modeling of sparse attention (InfLLM-V2) with the global efficiency of linear attention (Lightning Attention). By using a layer selection algorithm to integrate these mechanisms in a 1:3 ratio and adopting a hybrid positional encoding (HyPE), the model preserves both efficiency and performance on long-context tasks. Furthermore, we introduce a cost-effective continual training framework that transforms pre-trained Transformer-based models into hybrid models, reducing training costs by approximately 75% compared with training from scratch. Extensive experiments show that MiniCPM-SALA maintains general capabilities comparable to full-attention models while offering improved efficiency. On a single NVIDIA A6000D GPU, the model achieves up to 3.5x the inference speed of the full-attention model at a sequence length of 256K tokens and supports context lengths of up to 1M tokens, a scale at which traditional full-attention 8B models fail due to memory constraints.
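Linear attention is what lets most layers avoid a key/value cache that grows with context length. The sketch below uses generic kernelized causal linear attention with an assumed elu(x)+1 feature map; it is not Lightning Attention's exact formulation or its tiled kernel, and all function names are hypothetical. It computes the same outputs two ways: a quadratic form that revisits every earlier key and value, and a recurrent form that keeps only a fixed-size state (a d_k × d_v matrix plus a d_k vector) regardless of how many tokens have been processed.

```python
import math
from typing import List

Vector = List[float]


def feature_map(x: Vector) -> Vector:
    # elu(x) + 1: a common positive feature map for linear attention (assumed here).
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def linear_attention_quadratic(queries: List[Vector], keys: List[Vector],
                               values: List[Vector]) -> List[Vector]:
    """Reference causal form: token t re-reads all keys/values 0..t (memory grows with t)."""
    d_v = len(values[0])
    outputs = []
    for t, q in enumerate(queries):
        fq = feature_map(q)
        weights = [dot(fq, feature_map(keys[s])) for s in range(t + 1)]
        total = sum(weights)
        outputs.append([sum(w * values[s][j] for s, w in enumerate(weights)) / total
                        for j in range(d_v)])
    return outputs


def linear_attention_recurrent(queries: List[Vector], keys: List[Vector],
                               values: List[Vector]) -> List[Vector]:
    """Same outputs, but only a fixed-size running state is kept between tokens."""
    d_k, d_v = len(keys[0]), len(values[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # running sum of outer(phi(k), v)
    norm = [0.0] * d_k                          # running sum of phi(k)
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fk = feature_map(k)
        for i in range(d_k):
            norm[i] += fk[i]
            for j in range(d_v):
                state[i][j] += fk[i] * v[j]
        fq = feature_map(q)
        denom = dot(fq, norm)  # positive, since all features are positive
        outputs.append([sum(fq[i] * state[i][j] for i in range(d_k)) / denom
                        for j in range(d_v)])
    return outputs


queries = [[0.1, -0.2], [0.3, 0.0], [-0.5, 0.4]]
keys = [[0.2, 0.1], [-0.1, 0.3], [0.0, -0.4]]
values = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 1.0, 1.0]]
rec = linear_attention_recurrent(queries, keys, values)
ref = linear_attention_quadratic(queries, keys, values)
print(all(abs(a - b) < 1e-9 for r1, r2 in zip(rec, ref) for a, b in zip(r1, r2)))  # True
```

Because the recurrent state's size does not depend on sequence length, per-token decoding memory stays constant for these layers, whereas a full-attention layer must cache keys and values for every previous token. Sparse-attention layers generally still store keys and values, so this illustrates only the linear-attention share of the memory savings.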