🤖 AI Summary
This work addresses the degradation of shallow-layer information in deeply scaled large language models, where repeated residual updates dilute early features and impede their recovery in deeper layers. To mitigate this, we propose the Mixture-of-Depths Attention (MoDA) mechanism, which introduces cross-layer key-value fusion for the first time, enabling each attention head to jointly attend to key-value pairs from both the current and preceding layers. MoDA is integrated with a post-normalization architecture to enhance representational capacity. Furthermore, we design a memory-efficient algorithm compatible with FlashAttention-2, substantially reducing non-contiguous memory overhead. Experiments on a 1.5B-parameter model show that MoDA reduces average perplexity by 0.2 and improves downstream task performance by 2.11%, with only a 3.7% increase in FLOPs while achieving 97.3% of FlashAttention-2's computational efficiency, offering a new paradigm for deep model scaling.
📝 Abstract
Scaling depth is a key driver for large language models (LLMs). Yet, as LLMs become deeper, they often suffer from signal degradation: informative features formed in shallow layers are gradually diluted by repeated residual updates, making them harder to recover in deeper layers. We introduce mixture-of-depths attention (MoDA), a mechanism that allows each attention head to attend to sequence KV pairs at the current layer and depth KV pairs from preceding layers. We further describe a hardware-efficient algorithm for MoDA that resolves non-contiguous memory-access patterns, achieving 97.3% of FlashAttention-2's efficiency at a sequence length of 64K. Experiments on 1.5B-parameter models demonstrate that MoDA consistently outperforms strong baselines. Notably, it reduces average perplexity by 0.2 across 10 validation benchmarks and increases average performance by 2.11% on 10 downstream tasks, with a negligible 3.7% overhead in FLOPs. We also find that combining MoDA with post-norm yields better performance than using it with pre-norm. These results suggest that MoDA is a promising primitive for depth scaling. Code is released at https://github.com/hustvl/MoDA.
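To make the core idea concrete, here is a minimal single-head sketch of cross-layer KV fusion: queries at the current layer attend over a concatenation of the current layer's sequence KV pairs and a preceding layer's depth KV pairs. This is an illustrative simplification under our own assumptions (one preceding layer, no causal mask, no per-head routing, plain numpy instead of the paper's FlashAttention-2-compatible kernel); the actual MoDA design is specified in the paper.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moda_attention(q, k_cur, v_cur, k_prev, v_prev):
    """Single-head attention over fused KV pairs (illustrative sketch).

    q:      (T, d) queries at the current layer
    k_cur:  (T, d) current-layer keys   (sequence KV)
    v_cur:  (T, d) current-layer values
    k_prev: (T, d) a preceding layer's keys   (depth KV)
    v_prev: (T, d) a preceding layer's values
    """
    d = q.shape[-1]
    # Fuse along the key axis: each query jointly attends to the
    # current layer's tokens and the preceding layer's tokens.
    k = np.concatenate([k_cur, k_prev], axis=0)   # (2T, d)
    v = np.concatenate([v_cur, v_prev], axis=0)   # (2T, d)
    scores = q @ k.T / np.sqrt(d)                 # (T, 2T)
    return softmax(scores, axis=-1) @ v           # (T, d)
```

Note that when the depth KV pairs are identical to the current-layer ones, the fused softmax halves each duplicated weight and the result reduces exactly to standard attention, so the mechanism strictly widens the attention context rather than altering the baseline computation.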