🤖 AI Summary
This work addresses the longstanding trade-off in spiking neural network training between gradient accuracy and neuronal model flexibility: discrete-time approaches introduce gradient bias, while continuous-time methods are limited to simplistic neuron models. To overcome this, we propose Eventax, a framework built on JAX and Diffrax that integrates differentiable numerical ODE solvers with event-driven dynamics, enabling, for the first time, end-to-end exact gradient-based training of arbitrarily complex neuron models defined by ordinary differential equations—including multi-compartment cortical neurons. Users need only specify the neuron’s dynamics, spiking condition, and reset rules to automatically obtain unbiased gradients. We demonstrate successful training of diverse models—LIF, QIF, EIF, Izhikevich, and EGRU—on benchmarks such as Yin-Yang and MNIST, and further validate the feasibility of learning with biologically detailed multi-compartment human cortical neuron models.
📝 Abstract
Existing frameworks for gradient-based training of spiking neural networks face a trade-off: discrete-time methods using surrogate gradients support arbitrary neuron models but introduce gradient bias and constrain spike-time resolution, while continuous-time methods that compute exact gradients require analytical expressions for spike times and state evolution, restricting them to simple neuron types such as the Leaky Integrate-and-Fire (LIF) model. We introduce the Eventax framework, which resolves this trade-off by combining differentiable numerical ODE solvers with event-based spike handling. Built in JAX, our framework uses Diffrax ODE solvers to compute gradients that are exact with respect to the forward simulation for any neuron model defined by ODEs. It also provides a simple API in which users specify just the neuron dynamics, spike conditions, and reset rules. Eventax prioritises modelling flexibility, supporting a wide range of neuron models, loss functions, and network architectures, all of which can be easily extended. We demonstrate Eventax on multiple benchmarks, including Yin-Yang and MNIST, using diverse neuron models such as LIF, Quadratic Integrate-and-Fire (QIF), Exponential Integrate-and-Fire (EIF), Izhikevich, and the Event-based Gated Recurrent Unit (EGRU), with both time-to-first-spike and state-based loss functions, showing its utility for prototyping and testing event-based architectures trained with exact gradients. We further demonstrate its application to more complex neuron types by implementing a multi-compartment neuron that uses a model of dendritic spikes in human layer 2/3 cortical pyramidal neurons for computation. Code available at https://github.com/efficient-scalable-machine-learning/eventax.
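To make the "dynamics, spike condition, reset rule" interface concrete, here is a minimal pure-Python sketch of an event-handled LIF simulation in that style. All names (`lif_dynamics`, `spike_condition`, `simulate`, the parameter values) are illustrative, not the actual Eventax API, and the fixed-step Euler loop with interpolated crossing times merely stands in for the differentiable Diffrax solvers with event handling that Eventax actually uses.

```python
# Hypothetical sketch of the neuron-specification pattern described in the
# abstract: the user supplies three callables -- an ODE right-hand side, a
# spike (event) condition, and a reset rule. Not the real Eventax API.

def lif_dynamics(v, t, tau=10.0, i_ext=1.5):
    """dv/dt for a leaky integrate-and-fire neuron (dimensionless units)."""
    return (-v + i_ext) / tau

def spike_condition(v, threshold=1.0):
    """Event function: crosses zero from below when the neuron spikes."""
    return v - threshold

def reset(v, v_reset=0.0):
    """Reset rule applied to the membrane potential after a spike."""
    return v_reset

def simulate(dynamics, condition, reset_fn, v0=0.0, t0=0.0, t1=100.0, dt=0.01):
    """Fixed-step Euler integration; spike times are refined by linearly
    interpolating the zero-crossing of the event function within a step."""
    v, t = v0, t0
    spikes = []
    while t < t1:
        v_new = v + dt * dynamics(v, t)
        if condition(v) < 0.0 <= condition(v_new):
            # Interpolate where condition(v) crosses zero inside this step.
            frac = -condition(v) / (condition(v_new) - condition(v))
            spikes.append(t + frac * dt)
            v_new = reset_fn(v_new)
        v, t = v_new, t + dt
    return spikes

spike_times = simulate(lif_dynamics, spike_condition, reset)
```

With the constant drive above, the closed-form first spike time is tau * ln 3 ≈ 10.99, so the interpolated event times can be checked against the analytical solution; in Eventax the analogous event handling is differentiable, so gradients flow through these spike times.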