🤖 AI Summary
This work addresses the “grokking” phenomenon, in which generalization stagnates for a prolonged period of training before improving abruptly. We identify imbalanced gradient descent updates across principal directions as a cause of this stagnation. To mitigate it, we propose Egalitarian Gradient Descent (EGD), which uses the singular value decomposition to characterize the directional structure of gradients and applies a modified natural-gradient correction that equalizes step sizes across principal directions. EGD introduces no additional hyperparameters and substantially shortens, and in some cases eliminates, generalization stagnation on tasks including modular addition and sparse parity. Empirically, it accelerates and stabilizes grokking. Our core contribution is linking, both theoretically and empirically, the gradient dynamics along principal directions to the timing of grokking, and using directional equilibration to build an optimizer that accelerates it.
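The summary describes EGD only at a high level. The following is a minimal sketch that assumes "equalizing step sizes" for a matrix-shaped gradient G = U S Vᵀ means replacing every singular value by 1, which gives the direction U Vᵀ. For a full-row-rank G, this direction equals the preconditioned gradient (G Gᵀ)^{-1/2} G, which is one way to read "a modified natural gradient". The function names are hypothetical. The paper's exact update, for example its damping, its handling of zero singular values, or its stochastic estimates, may differ.

```python
import numpy as np

def egalitarian_direction(G: np.ndarray) -> np.ndarray:
    """Hypothetical EGD-style direction: keep G's singular vectors, set every singular value to 1."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def inverse_sqrt_preconditioned(G: np.ndarray) -> np.ndarray:
    """The same direction written as (G G^T)^{-1/2} G; assumes G has full row rank."""
    w, Q = np.linalg.eigh(G @ G.T)          # G G^T is symmetric positive definite here
    P = Q @ np.diag(w ** -0.5) @ Q.T         # (G G^T)^{-1/2}
    return P @ G

rng = np.random.default_rng(0)
# Rows deliberately placed on very different scales, so G's singular values are spread out.
G = rng.standard_normal((3, 5)) * np.array([[10.0], [1.0], [0.1]])
D = egalitarian_direction(G)

print(np.linalg.svd(G, compute_uv=False))   # widely spread singular values
print(np.linalg.svd(D, compute_uv=False))   # all numerically equal to 1
print(np.allclose(D, inverse_sqrt_preconditioned(G)))
```

An optimizer step under this assumption would be `W = W - lr * egalitarian_direction(G)`. Each principal direction then moves by the same amount `lr` per step, instead of by `lr` times its singular value.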
📝 Abstract
Grokking is the phenomenon whereby, unlike the training performance, which peaks early in the training process, the test/generalization performance of a model stagnates over arbitrarily many epochs and then suddenly jumps, usually to near-perfect levels. In practice, it is desirable to reduce the length of such plateaus, that is, to make the learning process "grok" faster. In this work, we provide new insights into grokking. First, we show both empirically and theoretically that grokking can be induced by asymmetric speeds of (stochastic) gradient descent along different principal directions (i.e., singular directions) of the gradients. We then propose a simple modification that normalizes the gradients so that the dynamics along all principal directions evolve at exactly the same speed. Next, we establish that this modified method, which we call egalitarian gradient descent (EGD) and which can be seen as a carefully modified form of natural gradient descent, groks much faster. In fact, in some cases the stagnation is completely removed. Finally, we empirically show that on classical arithmetic problems such as modular addition and sparse parity, on which this stagnation has been widely observed and intensively studied, our proposed method eliminates the plateaus.
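To illustrate what "asymmetric speeds along principal directions" means, here is a toy sketch on an assumed diagonal quadratic loss L(W) = ½‖diag(a)(W − W*)‖²_F. This is not a grokking experiment and only serves as an analogy for the plateau. Under plain gradient descent, the error along direction i shrinks by a factor (1 − lr·aᵢ²) per step, so weak directions lag far behind strong ones. The `normalized` update assumes that EGD's normalization maps G = U S Vᵀ to U Vᵀ; this is an interpretation of the abstract, not the authors' code. All names are hypothetical.

```python
import numpy as np

def normalized(G: np.ndarray) -> np.ndarray:
    """Assumed EGD-style normalization: U S V^T -> U V^T (all singular values set to 1)."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

def steps_per_direction(update, a, lr, tol=0.05, max_steps=10_000):
    """First step at which each diagonal error entry |E_ii| drops below tol.

    Toy loss (illustrative assumption): L(W) = 0.5 * ||diag(a) (W - W*)||_F^2,
    whose gradient is diag(a**2) @ E with E = W - W*.
    """
    E = np.eye(len(a))                    # initial error W_0 - W*
    H = np.diag(a ** 2)                   # per-direction curvature
    first = np.full(len(a), -1)
    for t in range(1, max_steps + 1):
        E = E - lr * update(H @ E)        # one optimizer step applied to the error
        newly = (np.abs(np.diag(E)) < tol) & (first < 0)
        first[newly] = t
        if (first > 0).all():
            break
    return first

a = np.array([1.0, 0.3, 0.1])             # strong, medium and weak directions
gd_steps = steps_per_direction(lambda G: G, a, lr=0.5)
egd_steps = steps_per_direction(normalized, a, lr=0.1)
print("GD :", gd_steps)    # weak directions take far longer to converge
print("EGD:", egd_steps)   # every direction crosses the tolerance at the same step
```

With a constant step size, the normalized update does not settle exactly at the minimum and instead oscillates with amplitude on the order of `lr`. The loop therefore stops as soon as every direction has first crossed the tolerance, because the comparison here concerns only per-direction speed.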