🤖 AI Summary
To address the tension between quantization-based compression and accuracy preservation in large language model (LLM) training, this paper proposes QuEST, a quantization-aware training (QAT) framework that trains stably even with 1-bit weights and activations. Methodologically, it introduces three key ingredients: (1) Hadamard normalization, which reshapes the continuous weight and activation distributions so they can be quantized accurately and quickly; (2) MSE-optimal fitting of the quantization grid to those distributions; and (3) a trust gradient estimator that explicitly minimizes the error between the noisy gradient computed over quantized states and the true (but unknown) full-precision gradient. On Llama-type architectures, QuEST is Pareto-competitive with FP16, i.e., it provides better accuracy at lower model size while training with weights and activations in 4 bits or less, and it induces stable scaling laws across the full range of hardware-supported precisions. The authors provide GPU kernels showing that the resulting models can be executed efficiently, and release their code.
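The first two ingredients can be sketched in a few lines of NumPy. This is an illustrative toy, not the paper's implementation: the Sylvester-constructed Hadamard matrix, the block size of 64, and the brute-force scale search standing in for "MSE-optimal fitting" are all our own simplifying assumptions.

```python
import numpy as np

def hadamard(n):
    # Orthonormal Hadamard matrix via the Sylvester construction
    # (assumes n is a power of 2).
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(n)

def mse_optimal_quantize(x, bits=4, grid=64):
    # Symmetric uniform quantizer; scan candidate scales and keep the one
    # minimizing mean-squared error (a simple stand-in for MSE-optimal fitting).
    q_max = 2 ** (bits - 1) - 1
    amax = np.abs(x).max()
    best_err, best_dq = np.inf, x
    for s in np.linspace(amax / grid, amax, grid) / q_max:
        dq = np.clip(np.round(x / s), -q_max, q_max) * s
        err = np.mean((dq - x) ** 2)
        if err < best_err:
            best_err, best_dq = err, dq
    return best_dq

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64))              # a toy weight block
H = hadamard(64)
w_rot = H @ w                              # Hadamard normalization of the block
w_hat = H.T @ mse_optimal_quantize(w_rot)  # quantize, then rotate back
print("reconstruction MSE:", np.mean((w_hat - w) ** 2))
```

Because the Hadamard transform is orthonormal, quantization error in the rotated space carries over unchanged to the original space, while the rotation makes the per-coordinate distribution closer to Gaussian and thus easier to fit with a uniform grid.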
📝 Abstract
One approach to reducing the massive costs of large language models (LLMs) is the use of quantized or sparse representations for training or deployment. While post-training compression methods are very popular, the question of obtaining even more accurate compressed models by directly training over such representations, i.e., Quantization-Aware Training (QAT), is still open: for example, a recent study (arXiv:2411.04330v2) put the "optimal" bit-width at which models can be trained using QAT, while staying accuracy-competitive with standard FP16/BF16 precision, at 8-bit weights and activations. We advance this state-of-the-art via a new method called QuEST, which is Pareto-competitive with FP16, i.e., it provides better accuracy at lower model size, while training models with weights and activations in 4-bits or less. Moreover, QuEST allows stable training with 1-bit weights and activations. QuEST achieves this by improving two key aspects of QAT methods: (1) accurate and fast quantization of the (continuous) distributions of weights and activations via Hadamard normalization and MSE-optimal fitting; (2) a new trust gradient estimator based on the idea of explicitly minimizing the error between the noisy gradient computed over quantized states and the "true" (but unknown) full-precision gradient. Experiments on Llama-type architectures show that QuEST induces stable scaling laws across the entire range of hardware-supported precisions, and can be extended to sparse representations. We provide GPU kernel support showing that models produced by QuEST can be executed efficiently. Our code is available at https://github.com/IST-DASLab/QuEST.