🤖 AI Summary
This work addresses two obstacles to learning large-scale modular addition: the task's input sensitivity, which makes it hard to learn, and the train/test distributional shift introduced by existing sparsification-based training. To mitigate these issues, the authors propose a covariate-shift-free training method that introduces an auxiliary modulus $Kq$ during training. This reduces wrap-around frequency, and thus task difficulty, while keeping the input distribution identical across training and testing. Existing sparsification strategies add zeros to training sequences and thereby create a mismatch between training and test inputs. The proposed approach avoids this mismatch and substantially improves generalization. Experimental results show that with $N=64$ and $q=974269$, the method achieves a τ-accuracy of 97.0% (at τ=0.05) using only 100,000 training samples. It substantially outperforms the sparse approach, which reaches only 9.5% with the same data size and 93.9% even with one million samples.
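The arithmetic behind the auxiliary modulus can be illustrated with a small Python sketch. It rests on several assumptions. Inputs are uniform in $[0, q)$. Training labels are taken modulo $Kq$. A wrap-around is counted as $\lfloor S / m \rfloor$ for a sum $S$ and modulus $m$. The value $K = 8$ is hypothetical, since the summary does not state it. The sketch covers only the arithmetic, not the authors' model or training procedure. Because $q$ divides $Kq$, reducing the auxiliary label modulo $q$ recovers the test-time target $S \bmod q$. The inputs themselves are drawn from the same distribution for training and testing.

```python
import random


def wraparounds(xs, modulus):
    """Count wrap-arounds as floor(sum / modulus) (one simple reading of the term)."""
    return sum(xs) // modulus


def sample_inputs(n, q, rng):
    """Uniform summands in [0, q); the same sampler serves training and testing."""
    return [rng.randrange(q) for _ in range(n)]


def demo(n=64, q=974269, k=8, trials=1000, seed=0):
    """Average wrap-arounds modulo q vs. modulo k*q on identical inputs.

    k=8 is a hypothetical choice of K; the source does not state its value.
    """
    rng = random.Random(seed)
    total_q = total_kq = 0
    for _ in range(trials):
        xs = sample_inputs(n, q, rng)
        s = sum(xs)
        train_label = s % (k * q)  # assumed training label under the auxiliary modulus Kq
        test_label = s % q         # target label under the original modulus q
        # q divides k*q, so reducing the auxiliary label mod q recovers the target.
        assert train_label % q == test_label
        total_q += wraparounds(xs, q)
        total_kq += wraparounds(xs, k * q)
    return total_q / trials, total_kq / trials


if __name__ == "__main__":
    avg_q, avg_kq = demo()
    print(f"mean wrap-arounds mod q:  {avg_q:.2f}")
    print(f"mean wrap-arounds mod Kq: {avg_kq:.2f}")
```

Under this counting rule, $\lfloor S/(Kq) \rfloor = \lfloor \lfloor S/q \rfloor / K \rfloor$. On the same inputs, the labels modulo $Kq$ therefore wrap around at most $1/K$ as often as labels modulo $q$.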
📝 Abstract
Learning parity functions, and more generally modular addition, is a challenging machine learning task due to its input sensitivity. A recent study substantially scaled up modular addition learning in both the number of summands and the modulus. Its key idea is to increase the number of zeros in training sequences, reducing the effective number of summands and thus controlling training difficulty; however, this induces a covariate shift between the training and test input distributions. This study theoretically and empirically analyzes this side effect and proposes a covariate-shift-free method for modular addition. Specifically, we introduce an auxiliary modulus $Kq$ during training, which reduces wrap-around frequency and problem difficulty while preserving the same input distribution across training and testing. Experiments show strong scalability and sample efficiency: even for large input length $N$, large modulus $q$, and small datasets (where the sparse method fails to learn), our method achieves equal or better match accuracy and relaxed $\tau$-accuracy. For example, at $N=64$ and $q=974269$, our method trained on 100K samples achieves $97.0\%$ $\tau$-accuracy at $\tau=0.05$, while the sparse method achieves only $9.5\%$ with the same data size and $93.9\%$ even when extended to 1M samples.
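A second sketch makes the covariate shift and the relaxed metric concrete, under stated assumptions. The abstract does not say how the prior method inserts zeros. `sample_sparse` is therefore a hypothetical stand-in that zeroes each entry independently with probability `p_zero`. The abstract also does not define $\tau$-accuracy. `tau_accuracy` assumes a prediction counts as correct when its circular distance to the target modulo $q$ is at most $\tau q$.

```python
import random


def sample_dense(n, q, rng):
    """Test-time inputs: n summands drawn uniformly from [0, q)."""
    return [rng.randrange(q) for _ in range(n)]


def sample_sparse(n, q, p_zero, rng):
    """Hypothetical sparsified training inputs: each entry zeroed with prob p_zero."""
    return [0 if rng.random() < p_zero else rng.randrange(q) for _ in range(n)]


def mean_nonzeros(seqs):
    """Average number of nonzero summands (the 'effective' number of summands)."""
    return sum(sum(1 for x in s if x != 0) for s in seqs) / len(seqs)


def tau_accuracy(preds, targets, q, tau):
    """Assumed definition: hit if the circular distance mod q is at most tau * q."""
    hits = 0
    for p, t in zip(preds, targets):
        d = (p - t) % q
        if min(d, q - d) <= tau * q:
            hits += 1
    return hits / len(preds)


if __name__ == "__main__":
    rng = random.Random(0)
    n, q = 64, 974269
    train = [sample_sparse(n, q, 0.75, rng) for _ in range(1000)]
    test = [sample_dense(n, q, rng) for _ in range(1000)]
    # Sparse training sequences have far fewer nonzero summands than test sequences.
    print(mean_nonzeros(train), mean_nonzeros(test))  # about 16 vs. about 64
```

With `p_zero=0.75`, the training sequences carry roughly a quarter of the nonzero summands seen at test time. This is exactly the kind of train/test mismatch that the auxiliary-modulus approach is designed to avoid.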