🤖 AI Summary
This work addresses the computational and memory bottlenecks of rollout in reinforcement learning (RL) for large language models (LLMs), where long-sequence generation incurs high attention cost and KV-cache memory pressure. FP8 low-precision computation can relieve these bottlenecks, but it also introduces training-inference inconsistency that can destabilize the policy. We present the first end-to-end FP8 rollout system for LLM reinforcement learning, integrating blockwise W8A8 quantization, FP8 KV caching, progressive QKV scale recalibration, and a token-level correction mechanism based on importance sampling (TIS/MIS) to mitigate precision-induced bias. Implemented within the veRL ecosystem, our approach is compatible with FSDP/Megatron-LM for training and vLLM/SGLang for inference, achieving up to 44% higher rollout throughput on both dense and MoE models while matching the learning performance of BF16 baselines.
📝 Abstract
Reinforcement learning (RL) for large language models (LLMs) is increasingly bottlenecked by rollout (generation), where long output sequence lengths make attention and KV-cache memory dominate end-to-end step time. FP8 offers an attractive lever for accelerating RL by reducing compute cost and memory traffic during rollout, but applying FP8 in RL introduces unique engineering and algorithmic challenges: policy weights change every step (requiring repeated quantization and weight synchronization into the inference engine), and low-precision rollouts can deviate from the higher-precision policy assumed by the trainer, causing train-inference mismatch and potential instability. This report presents a practical FP8 rollout stack for LLM RL, implemented in the veRL ecosystem with support for common training backends (e.g., FSDP/Megatron-LM) and inference engines (e.g., vLLM/SGLang). We (i) enable FP8 W8A8 linear-layer rollout using blockwise FP8 quantization, (ii) extend FP8 to the KV cache, removing the long-context memory bottleneck via per-step QKV scale recalibration, and (iii) mitigate the mismatch with importance-sampling-based rollout correction (token-level TIS/MIS variants). Across dense and MoE models, these techniques deliver up to 44% rollout throughput gains while preserving learning behavior comparable to BF16 baselines.
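To make the blockwise quantization idea concrete, here is a minimal numpy sketch (not the paper's implementation) of per-tile scaling in the style of blockwise FP8 W8A8: each tile of the weight matrix gets its own scale mapping its max magnitude onto the FP8 E4M3 dynamic range (max ≈ 448). Rounding to the nearest code on the scaled grid stands in for the actual FP8 cast; the tile size and function names are illustrative assumptions.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def quantize_blockwise(w: np.ndarray, block: int = 4):
    """Simulate blockwise FP8 quantization of a 2-D weight matrix.

    Each (block x block) tile gets an independent scale so that the tile's
    max magnitude maps to FP8_E4M3_MAX. Note: this models only the per-tile
    dynamic range, not the FP8 mantissa structure (illustrative sketch).
    Returns (codes, scales).
    """
    rows, cols = w.shape
    assert rows % block == 0 and cols % block == 0, "shape must tile evenly"
    scales = np.zeros((rows // block, cols // block))
    codes = np.zeros_like(w, dtype=np.float64)
    for i in range(0, rows, block):
        for j in range(0, cols, block):
            tile = w[i:i + block, j:j + block]
            amax = np.abs(tile).max()
            scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
            scales[i // block, j // block] = scale
            # round-to-nearest onto the per-tile grid (stand-in for FP8 cast)
            codes[i:i + block, j:j + block] = np.round(tile / scale)
    return codes, scales


def dequantize_blockwise(codes: np.ndarray, scales: np.ndarray, block: int = 4):
    """Reconstruct the weight by re-applying each tile's scale."""
    out = np.zeros_like(codes)
    rows, cols = codes.shape
    for i in range(0, rows, block):
        for j in range(0, cols, block):
            out[i:i + block, j:j + block] = (
                codes[i:i + block, j:j + block] * scales[i // block, j // block]
            )
    return out
```

Per-tile scales bound the reconstruction error of every tile by half a quantization step of that tile's own scale, which is why blockwise scaling tolerates outlier weights far better than a single per-tensor scale.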
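The token-level importance-sampling correction can likewise be sketched in a few lines. Assuming per-token log-probabilities from the (higher-precision) trainer policy and the FP8 rollout policy, each token is reweighted by the truncated ratio min(π_train/π_rollout, C), in the spirit of token-level TIS; the function names and the truncation constant are illustrative assumptions, not the paper's exact formulation.

```python
import numpy as np


def tis_weights(logp_train, logp_rollout, clip_max: float = 2.0):
    """Token-level truncated importance-sampling (TIS) weights.

    Each token's weight is pi_train(a|s) / pi_rollout(a|s), computed from
    log-probs and truncated from above to bound the variance contributed
    by rare large ratios (e.g., from FP8-vs-BF16 precision mismatch).
    """
    ratio = np.exp(np.asarray(logp_train) - np.asarray(logp_rollout))
    return np.minimum(ratio, clip_max)


def corrected_pg_loss(logp_train, logp_rollout, advantages, clip_max: float = 2.0):
    """Policy-gradient loss with per-token IS correction.

    In a real autograd implementation the weights would be detached
    (stop-gradient); here we just form the scalar loss value.
    """
    w = tis_weights(logp_train, logp_rollout, clip_max)
    return float(-(w * np.asarray(advantages) * np.asarray(logp_train)).mean())
```

When rollout and trainer policies agree exactly, every weight is 1 and the loss reduces to the uncorrected policy gradient, so the correction is a no-op in the matched-precision limit.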