🤖 AI Summary
Transformers face two key challenges in causal language modeling: inefficient KV cache memory utilization and the rigidity of static grouping strategies (e.g., Grouped-Query Attention, GQA), which cannot adapt to dynamic token importance. To address these, we propose *mixSGA*, the first token-level dynamic expert routing mechanism for attention. It employs learnable importance scores to enable fine-grained computation and memory allocation while retaining all tokens. To ensure training–inference consistency, we introduce cross-expert weight sharing and an auxiliary one-hot consistency loss. Integrating Mixture-of-Experts (MoE) principles with a generalized extension of GQA, mixSGA is evaluated on Llama3, TinyLlama, OPT, and Gemma2. Under identical KV cache memory budgets, it achieves significant improvements in ROUGE-L and reductions in perplexity, consistently outperforming static baselines, including GQA, across all models and metrics.
📝 Abstract
Transformer models face scalability challenges in causal language modeling (CLM) due to inefficient memory allocation for growing key-value (KV) caches, which strains compute and storage resources. Existing methods like Grouped Query Attention (GQA) and token-level KV optimization improve efficiency but rely on rigid resource allocation, often discarding "low-priority" tokens or statically grouping them, failing to address the dynamic spectrum of token importance. We propose mixSGA, a novel mixture-of-experts (MoE) approach that dynamically optimizes token-wise computation and memory allocation. Unlike prior approaches, mixSGA retains all tokens while adaptively routing them to specialized experts with varying KV group sizes, balancing granularity and efficiency. Our key novelties include: (1) a token-wise expert-choice routing mechanism guided by learned importance scores, enabling proportional resource allocation without token discard; (2) weight sharing across grouped attention projections to minimize parameter overhead; and (3) an auxiliary loss to ensure one-hot routing decisions for training–inference consistency in CLMs. Extensive evaluations across the Llama3, TinyLlama, OPT, and Gemma2 model families show mixSGA's superiority over static baselines. On instruction-following and continued-pretraining tasks, mixSGA achieves higher ROUGE-L and lower perplexity under the same KV budgets.
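To make the routing idea concrete, here is a minimal sketch of token-wise expert-choice routing by learned importance scores. All sizes, the scorer, and the rank-bucketing rule are illustrative assumptions, not the paper's implementation: each token is scored, ranked, and assigned one-hot to an expert whose KV group size matches its importance (finest group for the most important tokens), so every token is retained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes -- chosen for illustration, not taken from the paper.
d_model = 16
n_tokens = 8
group_sizes = [1, 2, 4]  # experts differ in KV group size (1 = finest, 4 = coarsest)

tokens = rng.normal(size=(n_tokens, d_model))
w_score = rng.normal(size=(d_model,))  # stands in for a learned importance scorer

def route(tokens, w_score, n_experts):
    """Token-wise expert choice: every token is kept and assigned to exactly
    one expert (one-hot routing), chosen by bucketing its importance rank."""
    scores = tokens @ w_score           # scalar importance per token
    order = np.argsort(-scores)         # tokens sorted by descending importance
    assignment = np.empty(len(scores), dtype=int)
    for rank, tok in enumerate(order):
        # High-importance ranks map to low expert indices (finer KV groups).
        assignment[tok] = rank * n_experts // len(scores)
    return scores, assignment

scores, assignment = route(tokens, w_score, len(group_sizes))
for tok in range(n_tokens):
    print(f"token {tok}: score={scores[tok]:+.2f} "
          f"-> expert with KV group size {group_sizes[assignment[tok]]}")
```

In a real model the argmax/bucketing step is non-differentiable, which is where the paper's auxiliary one-hot consistency loss and cross-expert weight sharing come in; this sketch only illustrates the allocation pattern.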