🤖 AI Summary
This work addresses catastrophic forgetting in continual learning by proposing a novel theoretical framework grounded in latent-representation identifiability: forgetting is formalized as a failure to identify latent variables shared across tasks. We are the first to cast catastrophic interference as a latent identifiability problem, and we prove that identifying and constraining task-shared latent variables guarantees minimal forgetting. Building on this insight, we devise a two-stage training strategy: (1) learning latent representations via maximum likelihood estimation, followed by (2) enforcing latent identifiability through KL-divergence regularization. Our approach bridges identifiability theory with representation learning, yielding substantial reductions in forgetting on synthetic data and on standard benchmarks, including Split MNIST, Permuted MNIST, and CIFAR-100.
📝 Abstract
Catastrophic interference, also known as catastrophic forgetting, is a fundamental challenge in machine learning: a trained model progressively loses performance on previously learned tasks as it adapts to new ones. In this paper, we aim to better understand and model catastrophic interference from a latent-representation-learning point of view, and we propose a novel theoretical framework that formulates it as an identification problem. Our analysis shows that the degree of forgetting can be quantified by the distance between the partial-task-aware (PTA) and all-task-aware (ATA) setups. Building upon recent advances in identifiability theory, we prove that this distance can be minimized by identifying the latent variables shared between these setups. Accordingly, we propose our method, ourmeos, with a two-stage training strategy: first, we employ maximum likelihood estimation to learn latent representations from both the PTA and ATA configurations; second, we optimize a KL-divergence objective to identify and learn the shared latent variables. Through theoretical guarantees and empirical validation on both synthetic and benchmark datasets, we establish that identifying and learning these shared representations effectively mitigates catastrophic interference.
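To make the second training stage concrete, below is a minimal NumPy sketch of a KL-divergence regularizer of the kind the abstract describes. It assumes diagonal-Gaussian latent posteriors for the PTA and ATA setups; this Gaussian assumption, the closed-form KL, and all names (`kl_diag_gaussian`, `mu_pta`, etc.) are illustrative choices, not details from the paper.

```python
import numpy as np

def kl_diag_gaussian(mu_p, var_p, mu_q, var_q):
    # Closed-form KL(N(mu_p, var_p) || N(mu_q, var_q)) for diagonal Gaussians:
    # 0.5 * sum( log(var_q/var_p) + (var_p + (mu_p - mu_q)^2)/var_q - 1 )
    return 0.5 * np.sum(
        np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0
    )

# Illustrative stage-2 step: penalize divergence between the shared-latent
# posterior of the partial-task-aware (PTA) model and the all-task-aware
# (ATA) reference, pushing the two setups to identify the same latents.
rng = np.random.default_rng(0)
mu_pta, var_pta = rng.normal(size=8), np.full(8, 0.5)   # hypothetical PTA posterior
mu_ata, var_ata = np.zeros(8), np.ones(8)               # hypothetical ATA posterior

kl_reg = kl_diag_gaussian(mu_pta, var_pta, mu_ata, var_ata)
```

In a full training loop, `kl_reg` would be added (with a weight) to the stage-one likelihood objective and minimized, so the PTA latents are driven toward the task-shared latents of the ATA setup.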