AI Summary
Existing LLM quantization-aware training (QAT) relies on high-precision weight copies to enable the straight-through estimator (STE), incurring substantial memory overhead. Method: This work proposes a fully low-bit quantized training paradigm that eliminates both the STE and floating-point weight replicas, enabling backpropagation directly on ternary (1.58-bit), 2-bit, or 8-bit quantized weights. Contribution/Results: We present the first end-to-end ternary-weight training for large language models, introduce stochastic rounding to mitigate gradient distortion, and adapt the framework to the LLaMA architecture. Experiments show that our 8-bit model incurs only a 5% loss degradation relative to BitNet b1.58, while significantly reducing training memory consumption and enabling native inference on quantized weights, breaking the long-standing dependency of QAT on high-precision weight copies for parameter updates.
Abstract
Although recent quantized Large Language Models (LLMs), such as BitNet, have paved the way for a significant reduction in memory usage during deployment with binary or ternary weights, training these models still demands a substantial memory footprint. This is partly because the high-precision (i.e., unquantized) weight matrices required for straight-through estimation must be maintained throughout the entire training process. To address this, we explore the potential of directly updating the quantized low-precision weight matrices without relying on the straight-through estimator during backpropagation, thereby saving memory during training. Specifically, we employ a stochastic rounding technique to minimize the information loss caused by the use of low-bit weights throughout training. Experimental results on our LLaMA-structured models indicate that (1) training with only low-precision weights is feasible even when they are constrained to ternary values, (2) extending the bit width to 8 bits results in only a 5% loss degradation compared to BitNet b1.58 while offering the potential for reduced memory usage during training, and (3) our models can also perform inference using ternary weights, showcasing their flexibility in deployment.
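The abstract's key ingredient, stochastic rounding of weights to a ternary grid, can be illustrated with a minimal NumPy sketch. This is not the paper's implementation: the function names (`stochastic_round`, `ternary_quantize`) and the BitNet b1.58-style absmean scaling are assumptions made here for illustration. The idea is that rounding each value up or down with probability proportional to its fractional part makes the rounding unbiased in expectation, which is what lets training proceed on the quantized weights themselves.

```python
import numpy as np

def stochastic_round(x, rng):
    """Round each entry to floor(x) or ceil(x) with probability equal to
    its fractional part, so that E[round(x)] == x (unbiased rounding)."""
    floor = np.floor(x)
    frac = x - floor                       # fractional part in [0, 1)
    return floor + (rng.random(x.shape) < frac)  # bool promotes to 0.0/1.0

def ternary_quantize(w, rng):
    """Map weights to {-1, 0, +1} * scale via stochastic rounding.
    The absmean scale mirrors BitNet b1.58 (an assumption, not the
    paper's exact recipe)."""
    scale = np.mean(np.abs(w)) + 1e-8      # per-tensor absmean scale
    q = stochastic_round(w / scale, rng)
    return np.clip(q, -1.0, 1.0), scale    # constrain to the ternary grid

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4)).astype(np.float32)
q, s = ternary_quantize(w, rng)
# q holds only -1, 0, +1; q * s is an unbiased (up to clipping)
# low-bit stand-in for w that can be updated directly during training.
```

In a deterministic round-to-nearest scheme, small gradient updates below half a quantization step would be lost every iteration; the stochastic variant preserves them on average, which is the stated motivation for using it in place of the STE's high-precision shadow weights.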