CAGE: Curvature-Aware Gradient Estimation for Accurate Quantization-Aware Training

📅 2025-10-21
🤖 AI Summary
In low-bit quantization-aware training (QAT), gradient bias introduced by the straight-through estimator (STE) causes substantial accuracy degradation. To address this, the authors propose CAGE, a curvature-aware gradient correction method grounded in a multi-objective view of QAT and Pareto optimality, yielding a correction term with convergence guarantees in the smooth non-convex setting. The correction integrates local curvature information via Adam's statistical moments, making it optimizer-agnostic and easy to integrate into existing training frameworks. When pre-training Llama-style models of up to 800M parameters in the W4A4 regime, CAGE recovers over 10% of the quantization-induced loss increase relative to state-of-the-art outlier-mitigation approaches, narrowing the gap to full-precision training.

📝 Abstract
Despite significant work on low-bit quantization-aware training (QAT), there is still a large accuracy gap between such techniques and native training. To address this, we introduce CAGE (Curvature-Aware Gradient Estimation), a new QAT method that augments the straight-through estimator (STE) gradient with a curvature-aware correction designed to counteract the loss increase induced by quantization. CAGE is derived from a multi-objective view of QAT that balances loss minimization with adherence to quantization constraints, yielding a principled correction term that depends on local curvature information. On the theoretical side, we introduce the notion of Pareto-optimal solutions for quantized optimization, and establish that CAGE yields strong convergence guarantees in the smooth non-convex setting. In terms of implementation, our approach is optimizer-agnostic, but we provide a highly efficient implementation that leverages Adam statistics. When pre-training Llama-style models of up to 800M parameters, CAGE recovers over 10% of the quantization-induced loss increase in the W4A4 regime over outlier-mitigation methods. These results indicate that curvature-aware gradient corrections can bridge the remaining performance gap beyond current outlier-handling methods.
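The abstract describes augmenting the STE gradient with a curvature-aware term derived from Adam's statistics. The sketch below is an illustrative toy version of that idea, not the paper's exact formula: it treats the square root of an Adam-style second-moment estimate as a diagonal curvature proxy and adds a correction proportional to the weight's distance from the quantization grid. The quantizer, the coefficient `lam`, and the function names are all assumptions made for illustration.

```python
def quantize(w: float, step: float = 0.5) -> float:
    """Uniform round-to-nearest quantizer (illustrative stand-in for W4 quantization)."""
    return step * round(w / step)

def cage_style_update(w: float, grad: float, second_moment: float,
                      lr: float = 0.1, lam: float = 0.1) -> float:
    """One hedged sketch of a curvature-corrected STE step.

    grad: the STE gradient, i.e. the loss gradient evaluated at quantize(w).
    second_moment: an Adam-style running average of grad**2, used here as a
    cheap diagonal curvature proxy (an assumption, not the paper's derivation).
    """
    curvature = second_moment ** 0.5
    # Correction pulls the latent weight toward the quantization grid,
    # scaled by the local curvature estimate.
    correction = lam * curvature * (w - quantize(w))
    return w - lr * (grad + correction)
```

With zero loss gradient, the correction alone nudges the latent weight toward its quantized value, which is the qualitative behavior the abstract attributes to the correction term.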
Problem

Research questions and friction points this paper is trying to address.

Large accuracy gap between low-bit QAT and full-precision training
Gradient bias from the straight-through estimator increases loss under quantization
Lack of convergence guarantees for quantized training in smooth non-convex settings
Innovation

Methods, ideas, or system contributions that make the work stand out.

Uses curvature-aware gradient correction for quantization
Derives correction from multi-objective Pareto-optimal framework
Leverages Adam statistics for efficient optimizer-agnostic implementation