🤖 AI Summary
This study addresses two gaps: the lack of a reproducible end-to-end fine-tuning and deployment pipeline for Gemma 4 31B on TPUs, and the unclear performance and cost trade-offs between TPUs and GPUs for large language model training and inference. The authors present what they describe as the first successful LoRA fine-tuning and inference serving of Gemma 4 31B on Google Cloud TPU, migrating an existing PyTorch GPU workflow to the JAX + Tunix/Qwix stack and adding Orbax-to-safetensors conversion alongside a vLLM-TPU Docker deployment. Compared with a 2×H100 GPU baseline, TPU v5p-8 trains 1.61× faster at 2.12× lower cost, while TPU v6e-8 delivers comparable inference throughput (within 3%) and roughly 50% lower time-to-first-token. The combined training and serving cost is 1.82× lower on TPU, which fills a gap in open-source TPU tooling and gives practitioners a documented path for porting GPU recipes to TPU.
📝 Abstract
We present the first end-to-end demonstration of fine-tuning and serving Google's Gemma 4 31B model on TPU hardware, providing an empirical comparison of TPU and GPU platforms for large language model adaptation. Using LoRA on a Google TPU v5p-8 for training and TPU v6e-8 (Trillium) for inference, we document the full set of code-level adaptations required to port a GPU-native training recipe, built on PyTorch, HuggingFace TRL, and FSDP, to the JAX + Tunix/Qwix stack. These adaptations span mesh configuration, LoRA module naming conventions, sharding annotation corrections, gradient checkpointing, data pipeline restructuring, and a custom Orbax-to-safetensors checkpoint merging procedure.
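The checkpoint merging step mentioned above folds the trained low-rank adapters back into the base weights before export. The sketch below illustrates only that arithmetic in JAX, using the standard LoRA update W' = W + (alpha / r) · A · B. It is not the authors' procedure: the function name `merge_lora`, the tensor shapes, and the `alpha` and `rank` values are assumptions for illustration, and the Orbax loading and safetensors writing steps are omitted.

```python
import jax
import jax.numpy as jnp


def merge_lora(w, lora_a, lora_b, alpha, rank):
    """Fold a LoRA adapter into a base weight matrix.

    Assumes the y = x @ w layout, with lora_a of shape (d_in, rank)
    and lora_b of shape (rank, d_out). Hypothetical helper, not part
    of Tunix, Qwix, or Orbax.
    """
    scale = alpha / rank
    return w + scale * (lora_a @ lora_b)


# Toy dimensions chosen for illustration only.
d_in, d_out, rank, alpha = 8, 4, 2, 4.0
k_w, k_a, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 4)

w = jax.random.normal(k_w, (d_in, d_out))
lora_a = jax.random.normal(k_a, (d_in, rank))
lora_b = jax.random.normal(k_b, (rank, d_out))  # stands in for trained values
x = jax.random.normal(k_x, (3, d_in))

# Forward pass with the adapter kept separate, as during training.
adapter_out = x @ w + (alpha / rank) * ((x @ lora_a) @ lora_b)

# Forward pass through the merged weight, as after export.
merged_w = merge_lora(w, lora_a, lora_b, alpha, rank)
merged_out = x @ merged_w

print(bool(jnp.allclose(adapter_out, merged_out, atol=1e-4)))
```

Because the merged matrix has the same shape as the base weight, a serving stack can load it like any ordinary checkpoint, with no adapter-aware code at inference time.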
For inference, we detail the vLLM-TPU Docker setup necessary to serve Gemma 4 on v6e-8 and characterize the resulting latency and throughput profile. Compared with a 2×H100 GPU baseline under identical hyperparameters, TPU training completes 1.61× faster at 2.12× lower cost. Inference throughput is within 3% across platforms, while TPU achieves roughly 2× lower time-to-first-token (235 ms vs. 475 ms). Taken together, the TPU configuration is 1.82× cheaper for a representative train-plus-serve workload.
Our work closes a critical gap in the open tooling ecosystem and provides practitioners with a reproducible, production-ready recipe for Gemma 4 deployment on TPU infrastructure.
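The ratios quoted in the abstract can be sanity-checked with a few lines of arithmetic. The Python sketch below recomputes the time-to-first-token ratio from the two reported latencies and, under the assumption that cost equals wall-clock time multiplied by an hourly price, derives the hourly price ratio implied by the reported training speedup and cost saving. The abstract does not state hourly prices, so that last figure is an inference, and the helper names are hypothetical.

```python
def speedup(baseline, candidate):
    """How many times smaller the candidate value is than the baseline."""
    return baseline / candidate


def percent_reduction(baseline, candidate):
    """Reduction of the candidate relative to the baseline, in percent."""
    return 100.0 * (baseline - candidate) / baseline


# Time-to-first-token figures reported in the abstract (milliseconds).
gpu_ttft_ms = 475.0
tpu_ttft_ms = 235.0

ttft_ratio = speedup(gpu_ttft_ms, tpu_ttft_ms)                # about 2.02
ttft_reduction = percent_reduction(gpu_ttft_ms, tpu_ttft_ms)  # about 50.5

# Training figures reported in the abstract.
training_speedup = 1.61     # GPU time / TPU time
training_cost_ratio = 2.12  # GPU cost / TPU cost

# Assumption: cost = hours * hourly_price, so
#   cost_ratio = speedup * (gpu_hourly_price / tpu_hourly_price).
implied_hourly_price_ratio = training_cost_ratio / training_speedup  # about 1.32

print(f"TTFT ratio: {ttft_ratio:.2f}x")
print(f"TTFT reduction: {ttft_reduction:.1f}%")
print(f"Implied GPU/TPU hourly price ratio: {implied_hourly_price_ratio:.2f}")
```

Under that cost model, about 1.61× of the 2.12× training saving comes from the shorter run and the remaining factor of roughly 1.32 from a lower hourly price for the TPU slice.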