🤖 AI Summary
Existing mixed-precision training methods struggle to balance training stability and model quality for large language models on sub-byte-precision GPUs, often leading to degraded generalization and convergence difficulties. This work proposes SNIP, the first fine-grained adaptive mixed-precision training framework tailored for sub-byte hardware. SNIP periodically collects statistical information from activations, gradients, and optimizer states to quantify divergence in forward loss and backward weight updates, formulating an integer linear programming (ILP) problem that dynamically optimizes per-layer precision configurations. Evaluated on Llama models ranging from 1B to 70B parameters, SNIP reduces FLOPs by up to 80% with negligible overhead while consistently preserving model quality throughout training, thereby significantly enhancing both training efficiency and stability.
📝 Abstract
Training large language models (LLMs) efficiently while preserving model quality poses significant challenges, particularly with the sub-byte precision supported by state-of-the-art GPUs. Current mixed-precision training approaches either apply uniform precision to all GEMM operations or rely on heuristic-based methods that fail to generalize during training, leading to suboptimal convergence and instability. To address these challenges, this paper introduces SNIP, a fine-grained adaptive mixed-precision training framework for LLM pretraining that supports sub-byte precision. SNIP periodically collects statistics on activations, gradients, and optimizer states to assess the impact of precision loss on model quality. We define two key metrics: loss divergence in the forward pass, which captures quantization-induced increases in training loss, and weight divergence in the backward pass, which measures error propagation through gradients affecting model updates. These metrics guide an Integer Linear Programming (ILP) problem that systematically optimizes layerwise precision to minimize overall quality loss while meeting efficiency targets. Experiments on 1B, 3B, 7B, and 70B Llama-like models demonstrate that SNIP consistently outperforms existing baselines, reducing FLOPs by up to 80% while preserving model quality across different model sizes and training phases with minimal computational overhead.
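To make the layerwise optimization concrete, the sketch below solves a toy version of the per-layer precision selection problem the abstract describes: minimize total estimated divergence subject to a relative-FLOPs budget. All names, precision formats, and divergence numbers are illustrative assumptions, not values from the paper, and the exhaustive search stands in for a real ILP solver (feasible here because the toy instance has only a few layers):

```python
from itertools import product

# Hypothetical candidate formats and their relative GEMM cost
# (illustrative values, not from the paper).
PRECISIONS = ["fp8", "fp6", "fp4"]
FLOPS = {"fp8": 1.0, "fp6": 0.75, "fp4": 0.5}

def solve_precision_ilp(divergence, flops_budget):
    """Toy per-layer precision ILP: choose one precision per layer to
    minimize total divergence subject to sum of relative FLOPs <= budget.
    `divergence[l][p]` is the estimated quality loss of layer l at
    precision p. Brute force is exact for this small instance; a real
    system would hand the same objective and constraint to an ILP solver.
    """
    n = len(divergence)
    best, best_cost = None, float("inf")
    for assignment in product(PRECISIONS, repeat=n):
        if sum(FLOPS[p] for p in assignment) > flops_budget:
            continue  # violates the efficiency target
        cost = sum(divergence[l][p] for l, p in enumerate(assignment))
        if cost < best_cost:
            best, best_cost = assignment, cost
    return best, best_cost

# Three layers with made-up divergence estimates: layer 0 is sensitive
# to quantization, layer 2 tolerates low precision well.
div = [
    {"fp8": 0.01, "fp6": 0.10, "fp4": 0.50},
    {"fp8": 0.01, "fp6": 0.03, "fp4": 0.08},
    {"fp8": 0.01, "fp6": 0.01, "fp4": 0.02},
]
assignment, cost = solve_precision_ilp(div, flops_budget=2.0)
print(assignment, round(cost, 3))  # → ('fp8', 'fp4', 'fp4') 0.11
```

The solver keeps the sensitive layer at high precision and pushes the tolerant layers down, which is the qualitative behavior an adaptive mixed-precision policy is meant to produce.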