🤖 AI Summary
This work addresses the challenges of unstable convergence and high training costs in ultra-low-bit quantization-aware training (QAT), which are primarily caused by outlier channels and inter-layer error accumulation. The authors propose a progressive QAT framework that combines block-wise progressive training, a nested integer quantization grid enabling "train-once, deploy-at-any-precision" flexibility, and a rounding-aware outlier channel splitting mechanism, formulated as an identity-preserving transformation, to suppress quantization error. Leveraging E4M3 microscaling groups and custom W2A2/W2A16 operators, the method narrows the WikiText2 perplexity gap to full precision to only 2.25 under the W2A2 configuration on Llama-2/3 models and achieves up to an 11× speedup over BF16 inference.
📝 Abstract
Training LLMs at ultra-low precision remains a formidable challenge. Direct low-bit QAT often suffers from convergence instability and substantial training costs, exacerbated by quantization noise from heavy-tailed outlier channels and error accumulation across layers. To address these issues, we present Bit-by-Bit, a progressive QAT framework with outlier channel splitting. Our approach integrates three key components: (1) block-wise progressive training that reduces precision stage by stage, ensuring stable initialization for low-bit optimization; (2) a nested structure of integer quantization grids that enables a "train once, deploy any precision" paradigm, allowing a single model to support multiple bit-widths without retraining; (3) rounding-aware outlier channel splitting, which mitigates quantization error while acting as an identity transform that preserves the quantized outputs. Furthermore, we adopt microscaling groups with E4M3 scales, capturing dynamic activation ranges in alignment with OCP/NVIDIA standards. To address the lack of efficient 2-bit kernels, we develop custom operators for both W2A2 and W2A16 configurations, achieving up to an 11$\times$ speedup over BF16. Under W2A2 settings, Bit-by-Bit significantly outperforms baselines such as BitDistiller and EfficientQAT on both Llama2 and Llama3, with a WikiText2 perplexity gap of only 2.25 relative to full-precision models.
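The "train once, deploy any precision" property of a nested integer grid can be illustrated with a minimal sketch. This is an assumption about the general construction, not the paper's exact scheme: quantize once on the highest-precision signed-integer grid, then derive lower-bit codes by dropping least-significant bits, so every low-bit level is by construction also a high-bit level. The function name `quantize_nested` is hypothetical.

```python
import numpy as np

def quantize_nested(w, scale, high_bits=4, low_bits=2):
    """Quantize once at high precision; derive a nested low-bit version.

    Hypothetical sketch: the low-bit code is obtained by zeroing the
    least-significant bits of the high-bit code, so the low-bit grid is
    a strict subset of the high-bit grid (the "nesting" property).
    """
    qmax = 2 ** (high_bits - 1) - 1
    qmin = -2 ** (high_bits - 1)
    # High-precision code on the signed grid [qmin, qmax].
    q_hi = np.clip(np.round(w / scale), qmin, qmax).astype(int)
    # Drop LSBs: arithmetic shift down then back up keeps only the
    # levels that are multiples of 2**(high_bits - low_bits).
    shift = high_bits - low_bits
    q_lo = (q_hi >> shift) << shift
    # Dequantized tensors at both precisions, sharing one scale.
    return q_hi * scale, q_lo * scale
```

Because the 2-bit levels are a subset of the 4-bit levels, a deployment target can truncate the stored codes to any supported bit-width without retraining or re-quantizing.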
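The identity-transform property of outlier channel splitting can also be sketched. The paper's variant is rounding-aware and preserves the quantized outputs; the minimal sketch below shows only the classic floating-point form of the idea: halve an outlier input channel of a linear layer and duplicate it, so the output is exactly unchanged while the per-channel weight range shrinks. The names `split_outlier_channel` and `duplicate_input` are hypothetical.

```python
import numpy as np

def split_outlier_channel(W, j):
    """Halve input channel j of W and append the halved column again.

    Paired with duplicating entry j of the input, the product W @ x is
    exactly preserved (dividing by 2 and re-summing equal halves is
    exact in floating point), while the outlier column's magnitude is
    halved, reducing the quantization range it forces on its group.
    """
    W_split = W.copy()
    W_split[:, j] /= 2.0
    return np.concatenate([W_split, W_split[:, j:j + 1]], axis=1)

def duplicate_input(x, j):
    # Matching transform on the activations: repeat entry j at the end.
    return np.concatenate([x, x[j:j + 1]])
```

A rounding-aware version would additionally choose the split so that the two halves round to codes whose sum matches the original channel's quantized output.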