Ragged Paged Attention: A High-Performance and Flexible LLM Inference Kernel for TPU

📅 2026-04-16
📈 Citations: 0
Influential: 0

🤖 AI Summary
Existing LLM inference kernels are largely GPU-centric and map poorly onto TPU architectures, particularly under the dynamic, ragged workloads common in serving. This work presents Ragged Paged Attention (RPA), a TPU attention kernel implemented with Pallas and Mosaic. RPA combines fine-grained tiling for efficient dynamic slicing over ragged memory, a custom software pipeline that fuses KV cache updates with attention computation, and a distribution-aware compilation strategy that generates specialized kernels for decode, prefill, and mixed workloads. On Llama 3 8B on TPU7x, it reaches up to 86% memory bandwidth utilization (MBU) in decode and 73% model FLOPs utilization (MFU) in prefill, and it is integrated into vLLM and SGLang as the primary TPU backend.
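The two headline figures are ratios of achieved throughput to hardware peak: memory bandwidth utilization for decode and FLOPs utilization for prefill. The sketch below shows how such ratios are commonly computed. The function names and every number in it are illustrative placeholders, not measurements or hardware specifications from the paper.

```python
def utilization(achieved: float, peak: float) -> float:
    """Return achieved / peak as a fraction."""
    if peak <= 0:
        raise ValueError("peak must be positive")
    return achieved / peak


def decode_mbu(bytes_moved: float, step_seconds: float, peak_bytes_per_s: float) -> float:
    # Decode is usually memory-bound: each step streams weights and KV cache,
    # so the useful metric is achieved bytes/s over peak bytes/s.
    return utilization(bytes_moved / step_seconds, peak_bytes_per_s)


def prefill_mfu(model_flops: float, step_seconds: float, peak_flops_per_s: float) -> float:
    # Prefill is usually compute-bound, so the metric is achieved FLOP/s
    # over peak FLOP/s.
    return utilization(model_flops / step_seconds, peak_flops_per_s)


# Placeholder inputs chosen only to make the arithmetic easy to follow.
mbu = decode_mbu(bytes_moved=8.0e9, step_seconds=0.01, peak_bytes_per_s=1.0e12)
mfu = prefill_mfu(model_flops=5.0e12, step_seconds=0.01, peak_flops_per_s=1.0e15)
print(f"MBU = {mbu:.0%}, MFU = {mfu:.0%}")
```

With real measurements, the byte and FLOP counts would come from the model's size and the workload shape, and the peaks from the accelerator's published specifications.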

📝 Abstract
Large Language Model (LLM) deployment is increasingly shifting to cost-efficient accelerators like Google's Tensor Processing Units (TPUs), as operators prioritize both performance and total cost of ownership (TCO). However, existing LLM inference kernels and serving systems remain largely GPU-centric, and there is no well-established approach for efficiently mapping LLM workloads onto TPU architectures, particularly under the dynamic and ragged execution patterns common in modern serving. In this paper, we present Ragged Paged Attention (RPA), a high-performance and flexible attention kernel for TPUs, implemented using Pallas and Mosaic. RPA addresses these challenges through three key techniques: (1) fine-grained tiling to enable efficient dynamic slicing over ragged memory, (2) a custom software pipeline that fuses KV cache updates with attention computation, and (3) a distribution-aware compilation strategy that generates specialized kernels for decode, prefill, and mixed workloads. Evaluated with Llama 3 8B on TPU7x, RPA achieves up to 86% memory bandwidth utilization (MBU) in decode and 73% model FLOPs utilization (MFU) in prefill. Integrated as the primary TPU backend in vLLM and SGLang, RPA provides a production-grade foundation for efficient TPU inference and offers practical insights into kernel design.
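To make "ragged" and "paged" concrete, here is a minimal plain-Python reference of what such a kernel computes. It is not the RPA implementation and includes none of its tiling, pipelining, or KV-update fusion. The function name, argument layout, and single-head causal formulation are assumptions made for illustration. Each sequence contributes a different number of query tokens (ragged), and its keys and values are found through a per-sequence page table into a shared pool of fixed-size pages (paged).

```python
import math


def ragged_paged_attention(queries, q_lens, kv_lens, page_tables,
                           k_pages, v_pages, page_size):
    """Hypothetical reference: single-head causal attention over a paged KV cache.

    queries     : flat list of query vectors, all sequences concatenated
    q_lens      : number of new query tokens per sequence
    kv_lens     : total cached tokens per sequence (new tokens included)
    page_tables : per sequence, the physical page ids it owns, in order
    k_pages, v_pages : pages[page_id][slot] is one key or value vector
    """
    outputs = []
    q_start = 0
    for q_len, kv_len, pages in zip(q_lens, kv_lens, page_tables):
        for i in range(q_len):
            q = queries[q_start + i]
            # Query i sits at absolute position kv_len - q_len + i and may
            # attend to every cached position up to and including itself.
            visible = kv_len - q_len + i + 1
            scores, values = [], []
            for pos in range(visible):
                # Translate the logical position into (physical page, slot).
                page, slot = pages[pos // page_size], pos % page_size
                k = k_pages[page][slot]
                dot = sum(a * b for a, b in zip(q, k))
                scores.append(dot / math.sqrt(len(q)))
                values.append(v_pages[page][slot])
            # Numerically stable softmax followed by a weighted sum of values.
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            dim = len(values[0])
            outputs.append([
                sum(w * v[d] for w, v in zip(weights, values)) / total
                for d in range(dim)
            ])
        q_start += q_len
    return outputs


# Mixed batch: sequence 0 decodes one token over 3 cached tokens that span two
# non-contiguous pages; sequence 1 prefills 2 tokens into a single page.
k_pages = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.5, 0.5]], [[2.0, 0.0], [0.0, 2.0]]]
v_pages = [[[6.0, 0.0], [9.0, 9.0]], [[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [3.0, 6.0]]]
queries = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]  # zero queries give uniform weights
out = ragged_paged_attention(queries, q_lens=[1, 2], kv_lens=[3, 2],
                             page_tables=[[2, 0], [1]],
                             k_pages=k_pages, v_pages=v_pages, page_size=2)
print(out)
```

Because the queries are all zero, every visible position gets equal weight, so each output row is simply the mean of the value vectors that query is allowed to see. That makes the page-table lookup and the causal limit easy to check by hand.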
Problem

Research questions and friction points this paper is trying to address.

LLM inference
TPU
ragged execution
attention kernel
memory bandwidth utilization
Innovation

Methods, ideas, or system contributions that make the work stand out.

Ragged Paged Attention
TPU
LLM inference
KV cache fusion
distribution-aware compilation