Direct Quantized Training of Language Models with Stochastic Rounding

📅 2024-12-06
🏛️ arXiv.org
📈 Citations: 1
✨ Influential: 0
📄 PDF
🤖 AI Summary
Existing LLM quantization-aware training (QAT) relies on high-precision weight copies to enable the Straight-Through Estimator (STE), incurring substantial memory overhead. Method: This work proposes a full low-bit quantization training paradigm that eliminates both the STE and floating-point weight replicas, enabling backward propagation directly on ternary (1.58-bit), 2-bit, or 8-bit quantized weights. Contribution/Results: We present the first end-to-end ternary-weight training for large language models; introduce stochastic rounding to mitigate gradient distortion; and adapt the framework to the LLaMA architecture. Experiments show that our 8-bit model incurs only a 5% loss degradation relative to BitNet b1.58, while significantly reducing training memory consumption and enabling native inference on quantized weights, breaking the long-standing dependency of QAT on high-precision weight copies for parameter updates.

๐Ÿ“ Abstract
Although recent quantized Large Language Models (LLMs), such as BitNet, have paved the way for significant reduction in memory usage during deployment with binary or ternary weights, training these models still demands substantial memory footprints. This is partly because high-precision (i.e., unquantized) weight matrices required for straight-through estimation must be maintained throughout the whole training process. To address this, we explore the potential of directly updating the quantized low-precision weight matrices without relying on the straight-through estimator during backpropagation, thereby saving memory usage during training. Specifically, we employ a stochastic rounding technique to minimize information loss caused by the use of low-bit weights throughout training. Experimental results on our LLaMA-structured models indicate that (1) training with only low-precision weights is feasible even when they are constrained to ternary values, (2) extending the bit width to 8 bits results in only a 5% loss degradation compared to BitNet b1.58 while offering the potential for reduced memory usage during training, and (3) our models can also perform inference using ternary weights, showcasing their flexibility in deployment.
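The abstract's two key ideas, stochastic rounding and updating the low-precision weights directly, can be sketched in pure Python. This is a minimal illustration, not the paper's exact formulation: the mean-absolute-value scale (the BitNet b1.58 convention) and the plain SGD step are assumptions chosen for clarity.

```python
import math
import random

def stochastic_round(x, rng):
    """Unbiased rounding: round x up with probability equal to its
    fractional part, down otherwise, so that E[result] == x."""
    lo = math.floor(x)
    return lo + (1 if rng.random() < x - lo else 0)

def quantize_ternary(weights, rng):
    """Ternary (1.58-bit) quantization sketch: scale by the mean absolute
    weight (an assumption, following the BitNet b1.58 convention), round
    stochastically, then clip each value to {-1, 0, +1}."""
    scale = sum(abs(w) for w in weights) / len(weights) or 1.0
    q = [max(-1, min(1, stochastic_round(w / scale, rng))) for w in weights]
    return q, scale

def direct_quantized_update(w_q, grad, lr, scale, rng):
    """One SGD step applied directly to a ternary weight, re-quantized with
    stochastic rounding: no full-precision master copy is ever kept."""
    target = w_q - lr * grad / scale  # real-valued value after the step
    return max(-1, min(1, stochastic_round(target, rng)))
```

Because stochastic rounding is unbiased, the quantized weight drifts toward the real-valued target in expectation over many steps, which is what lets training proceed without the straight-through estimator's high-precision copy.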
Problem

Research questions and friction points this paper is trying to address.

Reduce memory usage in training quantized LLMs
Enable direct updating of low-precision weights
Minimize information loss with stochastic rounding
Innovation

Methods, ideas, or system contributions that make the work stand out.

Directly updating quantized low-precision weights
Using stochastic rounding to minimize information loss
Supporting inference with ternary weights
Kaiyan Zhao
The University of Tokyo
Natural Language Processing
T. Tabaru
Fujitsu Limited, Kawasaki, Japan
Kenichi Kobayashi
Fujitsu Limited, Kawasaki, Japan
Takumi Honda
Fujitsu Limited, Kawasaki, Japan
Masafumi Yamazaki
Fujitsu Limited, Kawasaki, Japan
Yoshimasa Tsuruoka
The University of Tokyo
Natural Language Processing · Reinforcement Learning · Artificial Intelligence for Games