About the job
In this role, you will be tasked with not only maintaining the library but also proactively evolving it. You will move beyond simple bug fixing to explore experimental quantization algorithms, adding them to the library before customers even realize they need them. You will operate in a unique environment where you must balance the agility of open-source software with the reliability required by Google-scale production. You will need to obsess over both quality (preserving model accuracy) and performance (optimizing runtime). You will need to be comfortable deep-diving into low-level profiles to debug TPU/GPU bottlenecks, while also having the soft skills to communicate effectively with partner teams in Google DeepMind (GDM) and customer teams across Search and Ads. You will be defining how the world optimizes JAX models.
Responsibilities
Design and implement new quantization features (e.g., post-training quantization (PTQ), quantized training (QT), and on-device machine learning (ODML) support) to keep pace with the rapidly evolving JAX ecosystem.
Proactively research and implement experimental quantization algorithms (e.g., int2 numerics, dual-scale quantization, Hadamard transformation) to lead customer adoption rather than just reacting to requests.
Debug and optimize low-level performance issues.
Use Accelerated Linear Algebra (XLA) profiling tools (xprof) to analyze TPU/GPU execution traces, identify bottlenecks, and ensure the lowest-cost implementation of algorithms.
Manage the health of the codebase across two fronts: resolving issues on the public repository and triaging high-priority bugs for internal partners.
Write high-quality documentation, tutorials, and examples for the open-source community to lower the barrier to entry for new users.
Qualifications
Minimum
Bachelor’s degree or equivalent practical experience.
2 years of experience programming in Python or C++.
2 years of experience testing, maintaining, or launching software products, and 1 year of experience with software design and architecture.
2 years of experience with one or more of the following: Speech/audio (e.g., technology duplicating and responding to the human voice), reinforcement learning (e.g., sequential decision making), ML infrastructure, or specialization in another ML field.
2 years of experience with ML infrastructure (e.g., model deployment, model evaluation, optimization, data processing, debugging).
Preferred
Experience with JAX transformations (vmap, pjit, grad) and the underlying XLA compiler stack.
Experience reading high-level optimizer (HLO) code to understand exactly how Python code translates to hardware execution.
Understanding of quantization theory and techniques (PTQ, QAT, weight-only vs. activation quantization), low-precision numerics (int8, fp8, int4), and the mathematical implications of compression on model convergence.
Ability to interpret the output of low-level performance tools (e.g., xprof, TensorBoard) to identify padding issues, memory fragmentation, or SIMD utilization gaps when profiling and optimizing ML models on TPUs or GPUs.