PyLO: Towards Accessible Learned Optimizers in PyTorch

📅 2025-06-12
📈 Citations: 0
Influential: 0
🤖 AI Summary
Existing learned optimizers (e.g., VeLO) face deployment barriers in the PyTorch ecosystem due to their JAX dependency, immature tooling, and limited validation on real large-model pretraining. This work introduces PyLO, an open-source, PyTorch-native library that brings learned optimizers to large-scale pretraining through familiar, widely adopted workflows, fully compatible with standard training components such as learning rate schedules and weight decay. Key contributions include: (i) a CUDA-accelerated implementation of the lightweight *small_fc_lopt* learned optimizer architecture; (ii) a focus on realistic large-scale pretraining workloads (e.g., ViT-B/16) rather than synthetic or convex tasks; and (iii) a plug-and-play optimizer interface for seamless integration with existing PyTorch tooling. At batch size 32, ViT-B/16 training throughput increases from 39.36 to 205.59 samples/sec, and combining learned optimizers with schedules and weight decay yields further measurable gains.

📝 Abstract
Learned optimizers have been an active research topic over the past decade, with increasing progress toward practical, general-purpose optimizers that can serve as drop-in replacements for widely used methods like Adam. However, recent advances -- such as VeLO, which was meta-trained for 4000 TPU-months -- remain largely inaccessible to the broader community, in part due to their reliance on JAX and the absence of user-friendly packages for applying the optimizers after meta-training. To address this gap, we introduce PyLO, a PyTorch-based library that brings learned optimizers to the broader machine learning community through familiar, widely adopted workflows. Unlike prior work focused on synthetic or convex tasks, our emphasis is on applying learned optimization to real-world large-scale pre-training tasks. Our release includes a CUDA-accelerated version of the small_fc_lopt learned optimizer architecture from (Metz et al., 2022a), delivering substantial speedups -- from 39.36 to 205.59 samples/sec throughput for training ViT B/16 with batch size 32. PyLO also allows us to easily combine learned optimizers with existing optimization tools such as learning rate schedules and weight decay. When doing so, we find that learned optimizers can substantially benefit. Our code is available at https://github.com/Belilovsky-Lab/pylo
Problem

Research questions and friction points this paper is trying to address.

Making learned optimizers accessible in PyTorch
Addressing inaccessibility of JAX-based learned optimizers
Applying learned optimization to real-world large-scale tasks
Innovation

Methods, ideas, or system contributions that make the work stand out.

PyTorch-based library for learned optimizers
CUDA-accelerated small_fc_lopt architecture
Combines learned optimizers with existing tools
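To make the core idea concrete, here is a minimal, dependency-free sketch of how a *small_fc_lopt*-style learned optimizer works: a tiny fully connected network maps per-parameter features (here just the gradient and a momentum accumulator) to an update direction and magnitude. This is an illustrative assumption-laden sketch, not PyLO's actual implementation or API; in PyLO the network weights come from meta-training (e.g., VeLO), whereas here they are random placeholders.

```python
import math
import random

random.seed(0)

HIDDEN = 4  # hidden width chosen for illustration; not the real architecture size


def init_mlp(n_in, n_hidden, n_out):
    """Random weights standing in for meta-trained ones."""
    w1 = [[random.gauss(0, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]
    w2 = [[random.gauss(0, 0.5) for _ in range(n_hidden)] for _ in range(n_out)]
    return w1, w2


def mlp(x, w1, w2):
    """Tiny two-layer fully connected network with tanh hidden units."""
    h = [math.tanh(sum(wi * xi for wi, xi in zip(row, x))) for row in w1]
    return [sum(wi * hi for wi, hi in zip(row, h)) for row in w2]


class TinyLearnedOptimizer:
    """Conceptual per-parameter learned optimizer (hypothetical, for exposition)."""

    def __init__(self, lr=1e-3, beta=0.9):
        self.lr, self.beta = lr, beta
        self.w1, self.w2 = init_mlp(2, HIDDEN, 2)  # features: [grad, momentum]
        self.momentum = {}

    def step(self, params, grads):
        for i, (p, g) in enumerate(zip(params, grads)):
            # Maintain an exponential moving average of gradients.
            m = self.beta * self.momentum.get(i, 0.0) + (1 - self.beta) * g
            self.momentum[i] = m
            # The MLP outputs a direction and a log-magnitude, mirroring the
            # direction/magnitude decomposition used by learned-optimizer
            # architectures; applied independently to every parameter.
            direction, log_mag = mlp([g, m], self.w1, self.w2)
            params[i] = p - self.lr * direction * math.exp(0.1 * log_mag)
        return params


# Toy usage: a few update steps on a single scalar parameter of f(x) = x^2.
# (With random placeholder weights the updates are not expected to minimize f.)
opt = TinyLearnedOptimizer()
params = [5.0]
for _ in range(5):
    grads = [2 * params[0]]
    params = opt.step(params, grads)
```

The per-parameter structure is what makes a CUDA-accelerated kernel worthwhile: the same small network is evaluated for every parameter of the model, so naive elementwise evaluation dominates step time, which is consistent with the throughput gap the paper reports.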