🤖 AI Summary
This work addresses the limitations of ultra-low-bit quantization in Transformers, which often suffers from accuracy degradation and inadequate hardware support. The authors propose a binary-weight ternary-activation (BWTA) quantization scheme, integrating zero-point distortion analysis, magnitude-aligned projection scaling, and instruction-level parallel bit-packing. They further design custom CUDA kernels and a multi-stage annealing training strategy to enable algorithm-hardware co-optimization. On BERT, the approach incurs only a 3.5% average drop on GLUE (with five tasks showing less than 2% degradation), while large language models retain near full-precision perplexity. The inference kernels achieve 16–24× speedup over FP16, delivering end-to-end prefill throughput of 216–330 tokens/s and substantially reduced memory footprint.
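The core quantization idea (binary weights, ternary activations with tiny values projected to zero) can be sketched as follows. This is a minimal illustration of the general technique, not the paper's implementation: the per-tensor mean-|w| weight scale and the 0.7·mean|x| ternary threshold are common heuristics from the binary/ternary-network literature, standing in for the paper's magnitude-aligned projection scaling.

```python
import numpy as np

def binarize_weights(w):
    # Binary weights: sign(w) with a per-tensor scale (mean absolute
    # value). This scale is a common choice in binary networks; the
    # paper's magnitude-aligned projection factor may differ.
    alpha = np.abs(w).mean()
    return alpha * np.sign(w)

def ternarize_activations(x, ratio=0.7):
    # Ternary activations {-s, 0, +s}: small-magnitude values are
    # projected to zero, the intuition behind avoiding zero-point
    # distortion. The ratio*mean|x| threshold is an illustrative
    # heuristic, not the paper's formula.
    delta = ratio * np.abs(x).mean()
    q = np.where(np.abs(x) <= delta, 0.0, np.sign(x))
    # Rescale surviving entries to their mean magnitude.
    scale = np.abs(x[q != 0]).mean() if np.any(q != 0) else 1.0
    return scale * q
```

Projecting near-zero activations to exactly zero gives the quantizer a true zero level, which pure binarization (sign only) lacks.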
📝 Abstract
Ultra-low-bit quantization brings substantial efficiency gains to Transformer-based models, but accuracy degradation and limited GPU support hinder its wide adoption. In this paper, we analyze zero-point distortion in binarization and propose a Binary Weights & Ternary Activations (BWTA) quantization scheme, which projects tiny values to zero and preserves the accuracy of extremely low-bit models. For training, we propose Smooth Multi-Stage Quantization, combining a Levelwise Degradation Strategy and a Magnitude-Alignment Projection Factor to enable stable and fast convergence. For inference, we develop a BWTA MatMul CUDA kernel with instruction-level parallel bit-packing and comprehensive binary/ternary MatMul implementations for both linear and attention operators, allowing seamless integration across Transformer architectures. Experiments show that BWTA approaches full-precision performance on BERT, with an average 3.5% drop on GLUE and less than a 2% drop on five tasks, and achieves comparable perplexity and accuracy on LLMs. In efficiency, it delivers a 16–24× kernel-level speedup over FP16 on NVIDIA GPUs and end-to-end prefill throughput of 216–330 tokens/s with a lower memory footprint on LLMs. As an algorithm-hardware co-design, BWTA demonstrates practical, low-latency ultra-low-bit inference without sacrificing model quality.
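To see why bit-packing makes a BWTA MatMul fast, consider a single dot product between a ternary activation vector and a binary weight vector. Encoding the activations as two bitplanes (a nonzero mask and a sign plane) and the weights as one sign plane reduces the dot product to bitwise ops plus two popcounts, which is the kind of reduction a bit-packed CUDA kernel would run per 32/64-bit word. The sketch below is a plain-Python model of that idea under these encoding assumptions, not the paper's kernel:

```python
def pack_bits(bits):
    # Pack a sequence of 0/1 flags into an integer, bit i = element i.
    v = 0
    for i, b in enumerate(bits):
        v |= int(b) << i
    return v

def bwta_dot(x_ternary, w_binary):
    # x_ternary: values in {-1, 0, +1}; w_binary: values in {-1, +1}.
    # Per element: 0 if the activation is zero, +1 if signs agree,
    # -1 if they disagree. The dot product is therefore
    #   popcount(mask & ~(xs ^ ws)) - popcount(mask & (xs ^ ws)).
    n = len(x_ternary)
    mask = pack_bits([1 if v != 0 else 0 for v in x_ternary])  # nonzero plane
    xs   = pack_bits([1 if v > 0 else 0 for v in x_ternary])   # activation signs
    ws   = pack_bits([1 if v > 0 else 0 for v in w_binary])    # weight signs
    diff = xs ^ ws                         # 1 where signs disagree
    agree    = mask & ~diff & ((1 << n) - 1)
    disagree = mask & diff
    return bin(agree).count("1") - bin(disagree).count("1")
```

On GPU, `mask`, `xs`, and `ws` would live in packed registers and the two popcounts map to hardware `__popc` instructions, which is where the 16–24× kernel speedup over FP16 comes from.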