🤖 AI Summary
This work addresses the challenges of catastrophic forgetting and task identification in task-agnostic continual learning by proposing Functional Task Networks (FTN), a parameter-isolation method inspired by the neocortex. FTN uses a high-dimensional, self-organizing binary mask to activate subnetworks of small but deep networks, enabling unsupervised task discrimination and knowledge retention. A three-stage mask generation mechanism recovers the subnetwork of a previously trained task within a single gradient step, reducing the mask search from the combinatorial $O(C(H,K))$ to a near-linear $O(H)$. By combining a gradient-driven continuous mask, a spatial smoothing kernel, and k-winner-take-all binarization, the method achieves nearly zero forgetting with fine-grained smoothing (FTN-Slow), or trades some retention for speed with a large kernel and only 2 smoothing iterations (FTN-Fast), on three continual-learning benchmarks.
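To give a sense of the gap between the two complexity classes mentioned above, the following minimal sketch compares the number of K-of-H binary masks, C(H, K), with the H positions visited by a linear scan. The sizes `H = 1024` and `K = 64` are hypothetical values chosen only for illustration; they are not taken from the paper.

```python
import math

def subset_count(H: int, K: int) -> int:
    """Number of distinct binary masks with exactly K of H units active, C(H, K)."""
    return math.comb(H, K)

# Hypothetical sizes for illustration only (not from the paper).
H, K = 1024, 64

n_subsets = subset_count(H, K)   # size of the exhaustive top-k subset search space
n_scan = H                       # positions visited by a linear scan

print(f"C({H},{K}) has {len(str(n_subsets))} decimal digits")
print(f"linear scan visits {n_scan} positions")
```

Even at these modest sizes the subset count is astronomically larger than H, which is why an exhaustive search over masks is impractical.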
📝 Abstract
Block-sequential continual learning demands that a single model both protect prior solutions from catastrophic forgetting and determine at inference time, without task labels, which prior solution matches the current input. We present Functional Task Networks (FTN), a parameter-isolation method inspired by structural and dynamical motifs found in the mammalian neocortex. Similar to mixture-of-experts, this method uses a high-dimensional, self-organizing binary mask over a large population of small but deep networks, which are inspired by dendritic models of pyramidal neurons. The mask is produced by a three-stage procedure: (1) gradient descent on a continuous mask identifies task-relevant neurons, (2) a smoothing kernel biases the result toward spatial contiguity, and (3) k-winner-take-all binarizes the resulting group at a fixed capacity budget. Because each neuron is an independent deep network, as in mixture-of-experts, disjoint masks give exactly disjoint gradient updates, providing a structural guarantee against catastrophic forgetting. The three-stage procedure recovers the sub-network of a previously trained task in a single gradient step, providing unsupervised task segmentation at inference time. We test FTN on three continual-learning benchmarks: (1) a synthetic multi-task classification/regression generator, (2) MNIST with shuffled class labels (pure concept shift), and (3) Permuted MNIST (domain shift). On all three, FTN with fine-grained smoothing (FTN-Slow) achieves nearly zero forgetting. FTN with a large kernel and only 2 iterations of smoothing (FTN-Fast) trades some retention for increased speed. We show that the spatial organization mechanism reduces the effective mask search from the combinatorial top-k subset problem in $O(C(H,K))$ to a near-linear scan in $O(H)$ over compact cortical neighborhoods, which is parallelized by the gradient-based update.
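The three-stage mask procedure described in the abstract can be sketched roughly as follows. This is an illustrative NumPy sketch, not the authors' implementation: it assumes the neurons are laid out on a 1D line, uses a box (moving-average) smoothing kernel, starts the continuous mask at zero, and takes the loss gradient with respect to the mask as a precomputed input. The function name `three_stage_mask` and all parameter names and values are hypothetical.

```python
import numpy as np

def three_stage_mask(grad, k, kernel_width=5, n_smooth=2, lr=1.0):
    """Sketch of: gradient step -> spatial smoothing -> k-winner-take-all.

    grad: dLoss/dMask for each of H neurons (assumed 1D layout).
    Returns a boolean mask with exactly k active neurons.
    """
    H = grad.shape[0]
    # Stage 1: one gradient-descent step on a continuous mask (assumed to start at 0).
    # Neurons whose activation would reduce the loss have negative gradient,
    # so they receive a high score.
    m = np.zeros(H) - lr * grad
    # Stage 2: repeated box smoothing biases scores toward contiguous groups.
    kernel = np.ones(kernel_width) / kernel_width
    for _ in range(n_smooth):
        m = np.convolve(m, kernel, mode="same")
    # Stage 3: k-winner-take-all keeps the k highest-scoring neurons.
    winners = np.argpartition(m, -k)[-k:]
    mask = np.zeros(H, dtype=bool)
    mask[winners] = True
    return mask

# Toy gradient: a contiguous task-relevant block (20..27) with one dropout
# at index 23, plus an isolated spurious neuron at index 50.
H, K = 64, 8
grad = np.zeros(H)
grad[[20, 21, 22, 24, 25, 26, 27]] = -1.0
grad[50] = -1.0

mask = three_stage_mask(grad, k=K)
print(np.flatnonzero(mask))
```

In this toy case the raw gradient has eight equally strong candidates, including the isolated neuron at index 50. After smoothing, the scores inside the contiguous block dominate, so the selected mask fills the gap at index 23 and drops the isolated neuron.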