🤖 AI Summary
Conventional deep learning optimizers rely solely on minibatch gradient means, discarding rich per-sample gradient statistics, largely because computing such statistics is presumed prohibitively expensive under standard automatic differentiation (AD) frameworks.
Method: We propose an efficient method for computing per-sample gradient statistics by restructuring the AD computation graph and leveraging JAX’s vectorized transformations, enabling near-zero-overhead estimation of variance, sign distributions, and other statistics.
Contribution/Results: Our approach reveals the critical impact of sign-operation placement on convergence in signSGD and empirically refutes the classical assumption that variance dominates preconditioning in Adam. Experiments on Transformer models demonstrate the practical viability and performance gains of statistics-aware optimizers. This work provides a new lens for analyzing nonlinear optimization dynamics and establishes a novel paradigm for optimizer design grounded in fine-grained gradient statistics.
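The core primitive behind all of this is the per-example gradient. A minimal sketch of how such gradients (and statistics over them) can be obtained with JAX's vectorization transformation, using a hypothetical toy linear-regression loss (the model, data, and statistic names below are illustrative, not the paper's actual setup):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss for a toy linear model.
def per_example_loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2

# vmap over the batch axis yields one gradient per sample, rather than
# the single averaged gradient a standard loss-mean + jax.grad would give.
per_example_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))

w = jnp.array([1.0, -2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.5, -1.0, 0.0])

grads = per_example_grads(w, xs, ys)            # shape (batch, dim)
grad_mean = grads.mean(axis=0)                  # the usual minibatch gradient
grad_var = grads.var(axis=0)                    # per-coordinate variance
sign_frac = (jnp.sign(grads) > 0).mean(axis=0)  # fraction of positive signs
```

Note that this `vmap`-based route materializes the full `(batch, dim)` gradient tensor; the paper's AD-graph surgery is what removes that overhead, while `vmap` serves as the prototyping path.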
📝 Abstract
Training algorithms in deep learning usually treat a mini-batch of samples as a single object; they average gradients over the mini-batch, and then process the average in various ways. Computing other statistics beyond the average may have been seen as prohibitively resource intensive in automatic differentiation (AD) frameworks. We show that this is not the case. Generally, gradient statistics can be implemented through a surgery of the AD graph that, in some cases, incurs almost no computational or memory overhead compared to the mini-batch gradient computation. Additionally, we show that in certain classes of models, including transformers, JAX's vectorization transformation offers a viable implementation for prototyping and experimentation. We then revise our understanding of two nonlinear operations in optimization through the lens of per-example gradient transformations. We first study signSGD and show that the optimal placement of the sign operation in the gradient processing chain is crucial to success and can be predicted with a simple signal-to-noise ratio argument. Next, we study per-example variations of the Adam preconditioner, and show that optimization is best served when the preconditioner is dominated by the mean rather than the variance of the gradient distribution, in contrast to conventional wisdom. Overall, we demonstrate that per-example gradient information enables new analyses and possibilities for algorithm design.
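To make the sign-placement question concrete: the sign can be applied to the averaged gradient (standard signSGD) or per example before averaging. A minimal illustration on made-up gradient values (the numbers and the claim that placement matters here are an illustrative assumption, not the paper's prediction for any specific model):

```python
import jax.numpy as jnp

# Hypothetical per-example gradients for one parameter coordinate:
# most samples agree on a small positive signal; one outlier is large negative.
g = jnp.array([0.1, 0.2, 0.15, -1.0])

# Placement A (standard signSGD): average first, then take the sign.
# The outlier drags the mean negative, flipping the update direction.
sign_of_mean = jnp.sign(g.mean())   # -> -1.0

# Placement B: per-example signs first, then average (a majority vote).
# Three of four samples are positive, so the vote recovers the signal.
mean_of_signs = jnp.sign(g).mean()  # -> 0.5
```

The two placements disagree exactly when a few large-magnitude gradients overrule the majority sign, which is where a signal-to-noise-ratio argument of the kind the abstract mentions would come into play.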