🤖 AI Summary
This work addresses weak and noisy gate gradients in conditional depth routing, where a gate decision must propagate through many layers before it affects the language modeling loss. Under a fixed computational budget, the study compares MLP-based gates with JEPA-guided gates and examines how auxiliary losses, namely utility regression and pairwise ranking supervision, affect controller training. The key finding runs against common practice: removing these auxiliary losses improves LM loss and threshold-hit speed in all three seeds for both gate types, which the authors trace to a mismatch between the off-policy oracle labels used by these losses and the sparse execution policy. Removing them also lowers the training FLOPs proxy from ~1.53x to ~1.07x full-only and cuts wall-clock training time by ~39%. Under the standard recipe that includes the auxiliary losses, JEPA-guided gates show faster early-to-mid optimization than MLP gates, but this advantage disappears once the losses are removed.
📝 Abstract
Conditional depth execution routes a subset of tokens through a lightweight (cheap) FFN while the remaining tokens execute the standard full FFN at each controlled layer. The central difficulty is gate training: a gate decision must propagate through many layers before it influences the language modeling (LM) loss, so the resulting gradients are weak and noisy. Auxiliary losses are commonly stacked to stabilise training, yet the interactions among them, particularly between a predictive auxiliary and explicit score supervision, have not been systematically compared under controlled conditions.
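To make the routing concrete, here is a minimal pure-Python sketch of one controlled layer. It assumes, purely for illustration, that the 50% full-path budget is enforced by per-sequence top-k selection on gate scores and that both FFN paths act as residual updates; `route_layer`, `full_ffn` and `cheap_ffn` are hypothetical names, not the paper's code.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def route_layer(
    hidden: Sequence[Vector],
    scores: Sequence[float],
    full_ffn: Callable[[Vector], Vector],
    cheap_ffn: Callable[[Vector], Vector],
    full_fraction: float = 0.5,
) -> Tuple[List[Vector], List[bool]]:
    """Route the top-scoring tokens through full_ffn and the rest through cheap_ffn.

    Assumption (not from the paper): the budget is enforced per sequence by top-k.
    """
    n = len(hidden)
    k = int(full_fraction * n)  # number of tokens allowed on the full path (floor)
    # Rank token indices by descending gate score; ties go to the earlier token.
    ranked = sorted(range(n), key=lambda i: (-scores[i], i))
    full_set = set(ranked[:k])

    outputs: List[Vector] = []
    full_mask: List[bool] = []
    for i, h in enumerate(hidden):
        use_full = i in full_set
        update = full_ffn(h) if use_full else cheap_ffn(h)
        # Residual connection: each path adds its update to the hidden state.
        outputs.append([x + u for x, u in zip(h, update)])
        full_mask.append(use_full)
    return outputs, full_mask


# Toy usage: a "full" FFN that doubles its input and a "cheap" FFN with a zero update.
hidden = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]]
scores = [0.9, 0.1, 0.7, 0.3]
outputs, full_mask = route_layer(
    hidden,
    scores,
    full_ffn=lambda h: [2.0 * x for x in h],
    cheap_ffn=lambda h: [0.0 for _ in h],
)
print(full_mask)  # [True, False, True, False]
```

In a real model the gate scores would come from the controller and the cheap path would be a smaller FFN rather than a zero update; the sketch only shows how a fixed budget turns scores into a per-token execution mask.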
We evaluate two gate designs on a 157.5M-parameter decoder-only model with controller-only training, a 50% full-path budget, and 3-seed runs on a fineweb-edu subset. The MLP gate (G1) maps the current hidden state to a utility score. The JEPA-guided gate (G3) adds an action-conditional predictor that forecasts, per token and in a low-dimensional latent space, the outcome of executing the full versus the cheap path; these forecasts are aligned against a fixed target head. Under the standard recipe with oracle-style utility regression and pairwise rank supervision (util/rank), G3 improves early-to-mid optimisation over G1 in 3/3 seeds (lower average LM loss, faster threshold hits, and ~10.3x lower gradient norms), while the 20k-step endpoint LM losses of the two gates stay within a heuristic 0.005 reference margin.
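The G1 gate and the two util/rank auxiliaries can be sketched as follows. The abstract does not specify the loss forms, so this assumes a one-hidden-layer ReLU MLP for the gate, mean squared error for utility regression, and a logistic pairwise loss for rank supervision; the oracle utility labels are treated as given inputs, and all function names are hypothetical.

```python
import math
from typing import Sequence


def mlp_gate(h: Sequence[float], W1: Sequence[Sequence[float]], b1: Sequence[float],
             w2: Sequence[float], b2: float) -> float:
    """Hypothetical G1-style gate: hidden state -> ReLU layer -> scalar utility score."""
    z = [max(0.0, sum(w * x for w, x in zip(row, h)) + b) for row, b in zip(W1, b1)]
    return sum(w * zi for w, zi in zip(w2, z)) + b2


def util_regression_loss(scores: Sequence[float], oracle: Sequence[float]) -> float:
    """Assumed form of utility regression: mean squared error to oracle utility labels."""
    return sum((s - u) ** 2 for s, u in zip(scores, oracle)) / len(scores)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_rank_loss(scores: Sequence[float], oracle: Sequence[float]) -> float:
    """Assumed form of rank supervision: logistic loss over every pair (i, j) where
    the oracle ranks token i above token j, pushing score i above score j."""
    total, pairs = 0.0, 0
    for i in range(len(scores)):
        for j in range(len(scores)):
            if oracle[i] > oracle[j]:
                total += softplus(-(scores[i] - scores[j]))
                pairs += 1
    return total / pairs if pairs else 0.0


# Toy gate evaluation: z = [1.0, 0.0], so the score is 1.0 + 0.5.
score = mlp_gate([1.0, -1.0], W1=[[1.0, 0.0], [0.0, 1.0]], b1=[0.0, 0.0],
                 w2=[1.0, 1.0], b2=0.5)
print(score)  # 1.5
```

Both losses pull the gate toward the oracle labels, so they help only insofar as those labels reflect what the gate's own routing will actually produce.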
A key finding (ablation A3): jointly removing util/rank improves best and average LM loss and threshold-hit speed in 3/3 seeds for both gates, and the early-to-mid advantage of G3 over G1 disappears. We trace this to the off-policy oracle label, which assumes that all subsequent layers execute the full path, whereas gated execution routes only a fraction of tokens through it; under the current recipe this makes util/rank net-negative. Removing util/rank also cuts the training FLOPs proxy from ~1.53x to ~1.07x full-only and reduces wall-clock training time from 2.87h to 1.75h on a V100-32GB (~39%). Conclusions are scoped to the studied regime.
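The reported cost figures can be checked with simple arithmetic. The ~39% figure corresponds to wall-clock time; the FLOPs proxy itself drops by roughly 30%, and the overhead relative to full-only training falls from about 53% to about 7%. The snippet below is only a worked calculation on the numbers stated above; the constant names are illustrative.

```python
def relative_reduction(before: float, after: float) -> float:
    """Fractional reduction when a cost goes from `before` to `after`."""
    return (before - after) / before


# Reported figures: FLOPs proxy as a multiple of full-only training,
# and wall-clock hours on a V100-32GB, with and without util/rank.
FLOPS_WITH_UTIL_RANK, FLOPS_WITHOUT = 1.53, 1.07
HOURS_WITH_UTIL_RANK, HOURS_WITHOUT = 2.87, 1.75

flops_cut = relative_reduction(FLOPS_WITH_UTIL_RANK, FLOPS_WITHOUT)  # about 0.30
hours_cut = relative_reduction(HOURS_WITH_UTIL_RANK, HOURS_WITHOUT)  # about 0.39
overhead_before = FLOPS_WITH_UTIL_RANK - 1.0  # about 0.53 above full-only
overhead_after = FLOPS_WITHOUT - 1.0          # about 0.07 above full-only

print(f"FLOPs proxy cut: {flops_cut:.1%}, wall-clock cut: {hours_cut:.1%}")
```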