🤖 AI Summary
Direct training of Spiking Neural Networks (SNNs) suffers from membrane potential distribution drift across timesteps, causing threshold misalignment, imbalanced spiking activity, and severe gradient attenuation—particularly in deep layers.
Method: We propose a two-stage cooperative learning framework: (1) a forward pass employing an adaptive firing threshold dynamically calibrated to the evolving membrane potential distribution, enabling spatiotemporally aligned spike generation; and (2) a backward pass with gradient dynamic optimization, wherein surrogate gradients are spatiotemporally scaled during backpropagation to mitigate deep-layer gradient vanishing.
Contribution/Results: Our method introduces no additional trainable parameters while significantly enhancing training stability and convergence speed. It achieves state-of-the-art (SOTA) accuracy on multiple benchmark datasets. Moreover, it balances spike rates across timesteps and improves gradient coverage in deep layers by 32.7%, effectively reconciling accuracy, computational efficiency, and scalability.
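The summary gives no equations for the forward-pass threshold calibration, so the following is only a rough NumPy sketch of the general idea: a leaky integrate-and-fire layer whose per-timestep threshold is re-centered on the current membrane-potential distribution. The `mean + k·std` rule, the function name, and all parameters are illustrative assumptions, not the paper's actual calibration.

```python
import numpy as np

def lif_forward_adaptive(inputs, tau=0.5, k=1.0):
    """LIF forward pass over timesteps with a per-timestep adaptive
    threshold (hypothetical rule: mean + k * std of the potentials).

    inputs: array of shape (T, N) -- input current per timestep/neuron.
    Returns (spikes, thresholds).
    """
    T, N = inputs.shape
    v = np.zeros(N)                 # membrane potentials
    spikes = np.zeros((T, N))
    thresholds = np.zeros(T)
    for t in range(T):
        v = tau * v + inputs[t]     # leaky integration
        # Track the drifting potential distribution instead of
        # using a fixed firing threshold
        thresholds[t] = v.mean() + k * v.std()
        s = (v >= thresholds[t]).astype(float)
        spikes[t] = s
        v = v * (1.0 - s)           # hard reset for neurons that fired
    return spikes, thresholds
```

Because the threshold follows the distribution, roughly the same fraction of neurons crosses it at every timestep, which is one way to realize the "balanced firing rates across timesteps" the method claims.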
📝 Abstract
Brain-inspired spiking neural networks (SNNs) are recognized as a promising avenue for achieving efficient, low-energy neuromorphic computing. Direct training of SNNs typically relies on surrogate gradient (SG) learning to estimate derivatives of non-differentiable spiking activity. However, during training, the distribution of neuronal membrane potentials varies across timesteps and progressively deviates toward both sides of the firing threshold. When the firing threshold and SG remain fixed, this can cause imbalanced spike firing and diminished gradient signals, ultimately degrading SNN performance. To address these issues, we propose a novel dual-stage synergistic learning algorithm that achieves forward adaptive thresholding and backward dynamic SG. In forward propagation, we adaptively adjust thresholds based on the distribution of membrane potential dynamics (MPD) at each timestep, which enriches neuronal diversity and effectively balances firing rates across timesteps and layers. In backward propagation, drawing on the underlying association between MPD, threshold, and SG, we dynamically optimize the SG to enhance gradient estimation through spatio-temporal alignment, effectively mitigating gradient information loss. Experimental results demonstrate that our method achieves significant performance improvements. Moreover, it allows neurons to fire stable proportions of spikes at each timestep and increases the proportion of neurons that obtain gradients in deeper layers.
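The abstract's backward-pass claim is that when potentials drift away from a fixed threshold, a fixed-support surrogate gradient leaves most neurons with zero gradient, while an SG re-aligned to the current distribution restores coverage. A minimal NumPy illustration of that effect, using a standard triangular surrogate derivative (the re-centering/re-scaling rule here is an assumption for illustration, not the paper's optimization):

```python
import numpy as np

def surrogate_grad(v, threshold, width):
    """Triangular surrogate derivative of the Heaviside spike function:
    max(0, 1 - |v - threshold| / width) / width."""
    return np.maximum(0.0, 1.0 - np.abs(v - threshold) / width) / width

def gradient_coverage(v, threshold, width):
    """Fraction of neurons receiving a nonzero surrogate gradient."""
    return float((surrogate_grad(v, threshold, width) > 0).mean())

rng = np.random.default_rng(1)
# Deep-layer potentials whose distribution has drifted away from a
# fixed threshold of 1.0 (drift parameters are illustrative)
v = rng.normal(-0.5, 2.0, size=10_000)

# Fixed threshold and fixed SG support: few neurons near the threshold
fixed = gradient_coverage(v, threshold=1.0, width=0.5)

# Dynamic variant: re-center on the distribution and scale the SG
# support with its spread (hypothetical alignment rule)
dynamic = gradient_coverage(v, threshold=v.mean() + v.std(), width=v.std())
```

Here `dynamic` is several times larger than `fixed`, mirroring the abstract's claim that spatio-temporal alignment of the SG increases the proportion of deep-layer neurons that obtain gradients.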