🤖 AI Summary
Large-scale recommendation models lose accuracy and train more slowly when they adopt FP8 low-precision computation directly, because they are numerically sensitive, dominated by small matrix operations, and subject to high communication overhead. This work proposes LoKA, a framework that makes FP8 training efficient and practical through co-design of system and model components. LoKA first performs statistically driven online profiling on real data distributions (LoKA Probe), then introduces reusable model modifications that improve numerical stability (LoKA Mods), and finally employs a runtime scheduler that dynamically selects the fastest FP8 kernel that meets the accuracy constraints (LoKA Dispatch). Experiments demonstrate that LoKA significantly accelerates FP8 training while preserving model accuracy, offering the first viable low-precision training solution for large-scale recommendation systems.
📝 Abstract
Recent GPU generations deliver significantly higher FLOPs using lower-precision arithmetic, such as FP8. While FP8 has been applied successfully to large language models (LLMs), its adoption in large recommendation models (LRMs) has been limited. This is because LRMs are numerically sensitive, dominated by small matrix multiplications (GEMMs) followed by normalization, and trained in communication-intensive environments. Applying FP8 directly to LRMs often degrades model quality and prolongs training time. These challenges are inherent to LRM workloads and cannot be resolved merely by introducing better FP8 kernels. Instead, a system-model co-design approach is needed to integrate FP8 successfully. We present LoKA (Low-precision Kernel Applications), a framework that makes FP8 practical for LRMs through three principles: profile under realistic distributions to know where low precision is safe, co-design model components with hardware to expand where it is safe, and orchestrate across kernel libraries to maximize the gains. Concretely, LoKA Probe is a statistically grounded, online benchmarking method that learns activation and weight statistics and quantifies per-layer errors. This process identifies which sites are safe or unsafe, and fast or slow, for FP8 adoption. LoKA Mods is a set of reusable model adaptations that improve both numerical stability and execution efficiency with FP8. LoKA Dispatch is a runtime that uses the statistical insights from LoKA Probe to select the fastest FP8 kernel that satisfies the accuracy requirements.
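The probe-then-dispatch idea can be illustrated with a small sketch. It is not LoKA's implementation: the abstract does not specify the error metric, the scaling scheme, or any kernel names, so everything below is an assumption chosen for illustration. The helper names (`fake_fp8`, `relative_error`, `dispatch`), the per-tensor scaling, the relative Frobenius error, and the candidate kernels with their latencies are all hypothetical. FP8 is simulated in pure Python by rounding to a 4-bit significand and clamping to the E4M3 maximum of 448; subnormals and exponent underflow are ignored.

```python
import math
import random

E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


def fake_fp8(x: float) -> float:
    """Round x to a 4-bit significand (1 implicit + 3 mantissa bits), clamped to
    the E4M3 range. Simplification: subnormals and underflow are not modeled."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    q = math.ldexp(round(m * 16) / 16, e)
    return max(-E4M3_MAX, min(E4M3_MAX, q))


def matmul(a, b):
    """Plain nested-list matrix product, used as the full-precision reference."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def per_tensor_scale(mat):
    """Scale that maps the largest magnitude in mat onto E4M3_MAX (assumed scheme)."""
    amax = max(abs(v) for row in mat for v in row)
    return E4M3_MAX / amax if amax > 0 else 1.0


def quantize(mat, scale):
    return [[fake_fp8(v * scale) for v in row] for row in mat]


def relative_error(a, b):
    """Probe step (assumed metric): relative Frobenius error of a simulated
    FP8 GEMM against the full-precision product."""
    sa, sb = per_tensor_scale(a), per_tensor_scale(b)
    ref = matmul(a, b)
    approx = matmul(quantize(a, sa), quantize(b, sb))
    num = sum((r - p / (sa * sb)) ** 2
              for ref_row, app_row in zip(ref, approx)
              for r, p in zip(ref_row, app_row))
    den = sum(r * r for ref_row in ref for r in ref_row)
    return math.sqrt(num / den) if den > 0 else 0.0


def dispatch(candidates, max_error):
    """Dispatch step: lowest-latency candidate whose probed error fits the budget.
    Returns None when no candidate qualifies."""
    ok = [c for c in candidates if c["error"] <= max_error]
    return min(ok, key=lambda c: c["latency_ms"]) if ok else None


# Usage with made-up shapes, kernel names, and latencies.
random.seed(0)
acts = [[random.uniform(0.1, 1.0) for _ in range(8)] for _ in range(4)]
weights = [[random.uniform(0.1, 1.0) for _ in range(4)] for _ in range(8)]
probed = relative_error(acts, weights)

kernels = [
    {"name": "fp8_per_tensor", "latency_ms": 0.8, "error": probed},
    {"name": "bf16_fallback", "latency_ms": 1.5, "error": 0.0},
]
choice = dispatch(kernels, max_error=0.01)
print(f"probed error = {probed:.4f}, selected = {choice['name']}")
```

Including a higher-precision fallback among the candidates means `dispatch` always has a safe option when every FP8 variant exceeds the error budget, which mirrors the abstract's distinction between safe and unsafe sites.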