STEM: Scaling Transformers with Embedding Modules

📅 2026-01-15
📈 Citations: 3
Influential: 0
🤖 AI Summary
This work addresses the challenges of instability, load imbalance, and communication overhead in fine-grained sparse training by proposing a static, token-index-based sparsity method. Specifically, it replaces the up-projection layer in the Transformer feedforward network (FFN) with a local embedding lookup while retaining the gating and down-projection layers as dense components. This design eliminates the need for dynamic routing, enables asynchronous CPU prefetching, decouples parameter capacity from per-token computation and communication costs, and permits direct knowledge editing in the embedding space. Evaluated on 350M and 1B models, the approach achieves accuracy gains of 3–4% on benchmarks including ARC-Challenge, OpenBookQA, GSM8K, and MMLU, while reducing FFN parameter accesses and per-token FLOPs by approximately one-third.

📝 Abstract
Fine-grained sparsity promises higher parametric capacity without proportional per-token compute, but often suffers from training instability, load imbalance, and communication overhead. We introduce STEM (Scaling Transformers with Embedding Modules), a static, token-indexed approach that replaces the FFN up-projection with a layer-local embedding lookup while keeping the gate and down-projection dense. This removes runtime routing, enables CPU offload with asynchronous prefetch, and decouples capacity from both per-token FLOPs and cross-device communication. Empirically, STEM trains stably despite extreme sparsity. It improves downstream performance over dense baselines while reducing per-token FLOPs and parameter accesses (eliminating roughly one-third of FFN parameters). STEM learns embedding spaces with large angular spread, which enhances its knowledge storage capacity. Notably, this enhanced knowledge capacity comes with better interpretability: the token-indexed nature of STEM embeddings allows simple, interpretable knowledge editing and knowledge injection without any intervention in the input text or additional computation. In addition, STEM strengthens long-context performance: as sequence length grows, more distinct parameters are activated, yielding practical test-time capacity scaling. Across 350M and 1B model scales, STEM delivers up to ~3–4% accuracy improvements overall, with notable gains on knowledge- and reasoning-heavy benchmarks (ARC-Challenge, OpenBookQA, GSM8K, MMLU). Overall, STEM is an effective way of scaling parametric memory while providing better interpretability, better training stability, and improved efficiency.
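The mechanism described in the abstract can be sketched in a few lines. This is a minimal pure-Python illustration, not the authors' implementation: the dimensions, the SiLU activation, and all function and variable names (`stem_ffn`, `E_up`, etc.) are assumptions for exposition. A standard gated FFN computes `down(act(gate(x)) * up(x))`; the STEM idea is that `up(x)` becomes a static table lookup keyed by the token's vocabulary index, so the "up" vector is fetched rather than computed and no runtime routing is needed.

```python
# Illustrative sketch of a STEM-style FFN block (reconstructed from the
# abstract; dimensions and names are hypothetical, not from the paper).
import math
import random

random.seed(0)
d_model, d_ff, vocab = 4, 8, 16

def rand_matrix(rows, cols):
    return [[random.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

W_gate = rand_matrix(d_model, d_ff)  # dense gate projection (kept, per the paper)
W_down = rand_matrix(d_ff, d_model)  # dense down projection (kept, per the paper)
E_up = rand_matrix(vocab, d_ff)      # layer-local embedding table replacing W_up

def matvec(W, x):
    # y_j = sum_i W[i][j] * x[i]
    return [sum(w_ij * x_i for w_ij, x_i in zip(col, x)) for col in zip(*W)]

def silu(v):
    return [u / (1.0 + math.exp(-u)) for u in v]

def stem_ffn(x, token_id):
    gate = silu(matvec(W_gate, x))  # computed densely, as in a normal gated FFN
    up = E_up[token_id]             # static token-indexed lookup: no dynamic routing
    hidden = [g * u for g, u in zip(gate, up)]
    return matvec(W_down, hidden)

x = [0.1, -0.2, 0.3, 0.05]
y = stem_ffn(x, token_id=7)

# Knowledge editing in embedding space: overwriting one row of E_up changes
# only that token's FFN contribution, with no retraining of the dense parts.
E_up[7] = [0.0] * d_ff
y_edited = stem_ffn(x, token_id=7)  # the zeroed row now nulls the hidden state
```

Because the lookup is indexed by the token id rather than a learned router, the active parameters for a sequence are known before the forward pass, which is what makes the asynchronous CPU prefetching mentioned in the summary possible.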
Problem

Research questions and friction points this paper is trying to address.

sparsity
parametric capacity
training stability
knowledge storage
efficient scaling
Innovation

Methods, ideas, or system contributions that make the work stand out.

sparsity
embedding modules
parametric capacity
interpretability
efficient transformers