AI Summary
To address catastrophic forgetting in diffusion models under continual learning, this paper combines generative replay with an enhanced Elastic Weight Consolidation (EWC) technique. Methodologically, it exploits the geometric observation that per-sample gradients at low signal-to-noise ratios are approximately collinear, so the empirical Fisher is effectively rank-1 and aligned with the mean gradient. This enables a rank-1 empirical Fisher for EWC that captures the dominant curvature direction at no extra computational cost over the diagonal approximation, improving parameter stability and mitigating distributional drift. Concurrently, the diffusion model itself generates high-fidelity replay samples, so that replay encourages parameter sharing across tasks while the EWC penalty constrains replay-induced drift. Experiments on MNIST, FashionMNIST, CIFAR-10, and ImageNet-1k demonstrate substantial reductions in forgetting: average FID improves over replay-only and diagonal-EWC baselines, forgetting is nearly eliminated on the MNIST-based tasks, and it is reduced by roughly 50% on ImageNet-1k.
Abstract
Catastrophic forgetting remains a central obstacle for continual learning in neural models. Popular approaches -- replay and elastic weight consolidation (EWC) -- have limitations: replay requires a strong generator and is prone to distributional drift, while EWC implicitly assumes a shared optimum across tasks and typically uses a diagonal Fisher approximation. In this work, we study the gradient geometry of diffusion models, which can already produce high-quality replay data. We provide theoretical and empirical evidence that, in the low signal-to-noise ratio (SNR) regime, per-sample gradients become strongly collinear, yielding an empirical Fisher that is effectively rank-1 and aligned with the mean gradient. Leveraging this structure, we propose a rank-1 variant of EWC that is as cheap as the diagonal approximation yet captures the dominant curvature direction. We pair this penalty with a replay-based approach to encourage parameter sharing across tasks while mitigating drift. On class-incremental image generation datasets (MNIST, FashionMNIST, CIFAR-10, ImageNet-1k), our method consistently improves average FID and reduces forgetting relative to replay-only and diagonal-EWC baselines. In particular, forgetting is nearly eliminated on MNIST and FashionMNIST and is roughly halved on ImageNet-1k. These results suggest that diffusion models admit an approximately rank-1 Fisher. With a better Fisher estimate, EWC becomes a strong complement to replay: replay encourages parameter sharing across tasks, while EWC effectively constrains replay-induced drift.
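To make the rank-1 idea concrete: if the empirical Fisher is well approximated by the outer product of the mean gradient, F ≈ ḡ ḡᵀ, then the quadratic EWC penalty ½(θ−θ*)ᵀF(θ−θ*) collapses to a scalar projection and costs O(d), the same as the diagonal variant. The sketch below is illustrative only (function names and the NumPy setting are ours, not the paper's code):

```python
import numpy as np

def rank1_ewc_penalty(theta, theta_star, g_bar, lam):
    """Rank-1 EWC penalty under the approximation F ~ g_bar g_bar^T.

    theta      -- current parameter vector
    theta_star -- parameters frozen after the previous task
    g_bar      -- mean per-sample gradient on the previous task
                  (low-SNR gradients are nearly collinear, so the
                  empirical Fisher is effectively rank-1 along g_bar)
    lam        -- regularization strength

    The full quadratic form (lam/2)(theta - theta_star)^T F (theta - theta_star)
    reduces to (lam/2) * ((theta - theta_star) @ g_bar)^2 -- one dot
    product, the same O(d) cost as the diagonal approximation.
    """
    proj = np.dot(theta - theta_star, g_bar)
    return 0.5 * lam * proj ** 2

# Example: a parameter shift orthogonal to g_bar incurs no penalty,
# while a shift along g_bar is penalized quadratically.
g = np.array([1.0, 0.0])
theta0 = np.zeros(2)
print(rank1_ewc_penalty(np.array([0.0, 3.0]), theta0, g, 2.0))  # orthogonal shift -> 0.0
print(rank1_ewc_penalty(np.array([1.0, 0.0]), theta0, g, 2.0))  # aligned shift -> 1.0
```

In training, this penalty would simply be added to the diffusion loss on the current (real plus replayed) data; only ḡ and θ* need to be stored per task, matching the memory footprint of diagonal EWC.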