🤖 AI Summary
This work addresses a gap in the JAX reinforcement learning ecosystem: it has no benchmarks for first-person visual tasks, which are important for studying exploration and partial observability, and slow environments limit how quickly researchers can iterate on such problems. To bridge this gap, the authors introduce JAXenstein, an open-source, batchable first-person visual reinforcement learning benchmark built in JAX that implements the Wolfenstein 3D rendering engine. By building on JAX, JAXenstein supports fast and scalable experimentation. The authors report that it is several times faster than comparable vision-based benchmarks and that it can be extended to more complex first-person domains.
📝 Abstract
The progression of reinforcement learning algorithms has been driven by challenging benchmarks. The rate at which a researcher can iterate on a problem setting directly impacts the speed of algorithm development. Modern machine learning has produced tools, such as the JAX library, that allow for fast and scalable algorithm development. With these tools available, a serious remaining bottleneck in algorithm development is the scarcity of large and complex domains for experimentation. Most notably, the JAX reinforcement learning ecosystem does not have any benchmarks that test visual first-person tasks; these domains are crucial for testing both exploration and an agent's ability to overcome partial observability. We introduce JAXenstein: an open-source JAX-based benchmark that implements the Wolfenstein 3D rendering engine for fast and scalable experimentation in visual first-person tasks. JAXenstein is several times faster than comparable vision-based benchmarks, and is easily extensible to more complex first-person domains.
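To illustrate why a grid-based first-person renderer suits JAX, the following is a minimal sketch, not JAXenstein's actual code or API. It assumes a hypothetical 5x5 map and uses fixed-step ray marching, which is simpler than the grid-traversal ray casting a Wolfenstein-style engine would use. The names `GRID`, `cast_ray`, `render_depth`, and `batched_render` are invented for this example. The point is only that one ray function can be vectorized over rays and then over a batch of agents with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Hypothetical 5x5 map (an assumption for this sketch): 1 = wall, 0 = floor.
GRID = jnp.array([
    [1, 1, 1, 1, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 0, 0, 0, 1],
    [1, 1, 1, 1, 1],
])

STEP = 0.01       # marching step size, in grid cells
MAX_STEPS = 1000  # rays are marched at most STEP * MAX_STEPS cells


def cast_ray(pos, angle):
    """Distance from pos along angle to the first wall cell (fixed-step march)."""
    direction = jnp.array([jnp.cos(angle), jnp.sin(angle)])
    dists = jnp.arange(1, MAX_STEPS + 1) * STEP
    points = pos + dists[:, None] * direction  # shape (MAX_STEPS, 2)
    # Clip so samples that leave the map read the border wall cells.
    cells = jnp.clip(jnp.floor(points).astype(jnp.int32), 0, GRID.shape[0] - 1)
    hit = GRID[cells[:, 1], cells[:, 0]] == 1  # indexed as [row = y, col = x]
    first = jnp.argmax(hit.astype(jnp.int32))  # index of the first wall sample
    return jnp.where(hit.any(), dists[first], jnp.inf)


def render_depth(pos, heading, fov=jnp.pi / 3, num_rays=8):
    """One depth value per ray across the field of view."""
    angles = heading + jnp.linspace(-fov / 2, fov / 2, num_rays)
    return jax.vmap(cast_ray, in_axes=(None, 0))(pos, angles)


# Vectorize over a batch of agents and compile the whole renderer once.
batched_render = jax.jit(jax.vmap(render_depth, in_axes=(0, 0)))

positions = jnp.array([[2.5, 2.5], [1.5, 1.5]])  # (x, y) for each agent
headings = jnp.array([0.0, jnp.pi / 2])          # radians
depths = batched_render(positions, headings)     # shape (2, 8)
```

A full renderer would turn each depth into a wall-column height and add textures; the sketch stops at depths to keep the batching pattern visible.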