🤖 AI Summary
This work addresses the limited generalization of brain decoding models across brain regions, recording sessions, behavioral paradigms, and subjects by proposing the Robust Pretrained Neural Transformer (RPNT). RPNT integrates experimental metadata through multidimensional rotary positional encoding, captures local temporal structure via a convolutional kernel-based contextual attention mechanism, and employs a robust self-supervised pretraining objective combining causal masking with contrastive learning. Evaluated on multi-session, multi-task electrophysiological data from both microwire arrays and high-density Neuropixels probes, RPNT consistently outperforms existing methods across diverse cross-domain motor decoding tasks, achieving substantial improvements in both decoding accuracy and generalization performance.
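The summary's pretraining objective (causal masking combined with contrastive learning) is not specified in detail here; a minimal illustrative sketch of such a combined loss, assuming mean-squared masked reconstruction plus an InfoNCE-style contrastive term (all function names and the `lam` weighting are hypothetical, not from the paper), might look like:

```python
import numpy as np

def causal_mask_indices(T, mask_frac=0.3, rng=None):
    """Uniformly sample time bins to mask; the model would predict
    them from preceding (causal) context only."""
    rng = rng or np.random.default_rng()
    n = max(1, int(T * mask_frac))
    return rng.choice(T, size=n, replace=False)

def info_nce(anchors, positives, temperature=0.1):
    """InfoNCE contrastive loss: each anchor should be most similar
    to its own positive among all positives in the batch."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = a @ p.T / temperature               # (B, B) cosine similarities
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

def pretraining_loss(pred, target, mask_idx, anchors, positives, lam=0.5):
    """Combined objective: masked reconstruction + contrastive term."""
    recon = np.mean((pred[mask_idx] - target[mask_idx]) ** 2)
    return recon + lam * info_nce(anchors, positives)
```

This is a sketch of the general recipe, not RPNT's actual implementation; the paper's masking schedule, view construction, and loss weighting may differ.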
📝 Abstract
Brain decoding aims to interpret neural activity and translate it into behavior. It is therefore imperative that decoding models generalize across variations such as recordings from different brain sites, distinct sessions, different types of behavior, and a variety of subjects. Current models address these challenges only partially, warranting the development of pretrained neural transformer models capable of adapting and generalizing. In this work, we propose RPNT (Robust Pretrained Neural Transformer), designed to achieve robust generalization through pretraining, which in turn enables effective finetuning on a downstream task. RPNT's unique components include: 1) multidimensional rotary positional embedding (MRoPE) to aggregate experimental metadata such as site coordinates, session identity, and behavior type; 2) a context-based attention mechanism in which convolution kernels operate on global attention to learn local temporal structure, handling the non-stationarity of neural population activity; 3) a robust self-supervised learning (SSL) objective combining uniform causal masking strategies with contrastive representations. We pretrained two separate versions of RPNT on distinct datasets: a) a multi-session, multi-task, and multi-subject microelectrode benchmark; and b) multi-site recordings using high-density Neuropixels 1.0 probes. The datasets include recordings from the dorsal premotor cortex (PMd) and primary motor cortex (M1) of nonhuman primates (NHPs) performing reaching tasks. After pretraining, we evaluated the generalization of RPNT on cross-session, cross-type, cross-subject, and cross-site downstream behavior decoding tasks. Our results show that RPNT consistently matches or surpasses the decoding performance of existing decoding models across all tasks.
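The abstract does not detail how MRoPE encodes metadata. A minimal sketch of the general idea, assuming the feature dimension is split into one chunk per metadata axis (e.g., time step, site coordinate, session index) and each chunk is rotated by that axis's position, standard-RoPE style (the function names and chunking scheme are illustrative assumptions, not RPNT's actual design):

```python
import numpy as np

def rotary_embed(x, positions, base=10000.0):
    """Standard rotary embedding: rotate consecutive feature pairs of x
    by angles proportional to the given scalar positions."""
    d = x.shape[-1]
    assert d % 2 == 0, "feature dimension must be even"
    # One frequency per feature pair, geometrically spaced.
    freqs = base ** (-np.arange(0, d, 2) / d)       # (d/2,)
    angles = positions[..., None] * freqs           # (..., d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def multidim_rope(x, meta):
    """Multidimensional RoPE sketch: one feature chunk per metadata
    axis, each rotated by its own position coordinate."""
    n_axes = meta.shape[-1]
    chunks = np.split(x, n_axes, axis=-1)
    rotated = [rotary_embed(c, meta[..., i]) for i, c in enumerate(chunks)]
    return np.concatenate(rotated, axis=-1)
```

Because rotations are norm-preserving, this injects positional and metadata information into queries and keys without changing feature magnitudes, which is one reason rotary schemes compose well with attention.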