Avoid Catastrophic Forgetting with Rank-1 Fisher from Diffusion Models

πŸ“… 2025-09-27
πŸ“ˆ Citations: 0
✨ Influential: 0
πŸ€– AI Summary
To address catastrophic forgetting in diffusion models under continual learning, this paper combines generative replay with an enhanced Elastic Weight Consolidation (EWC) penalty. Methodologically, it exploits the geometric observation that per-sample gradients at low signal-to-noise ratios are approximately collinear, so the empirical Fisher matrix is effectively rank-1. This rank-1 Fisher captures the dominant curvature direction at the same computational cost as the standard diagonal approximation, improving parameter stability. Replay samples are generated by the diffusion model itself, and the two components complement each other: replay encourages parameter sharing across tasks, while the rank-1 EWC penalty constrains replay-induced distributional drift. Experiments on class-incremental MNIST, FashionMNIST, CIFAR-10, and ImageNet-1k show consistently better average FID than replay-only and diagonal-EWC baselines; forgetting is nearly eliminated on the MNIST-based tasks and roughly halved on ImageNet-1k.
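The claimed rank-1 structure can be sanity-checked numerically: if per-sample gradients are nearly collinear, the stacked gradient matrix has one dominant singular value and every row aligns with the mean gradient. A minimal synthetic sketch (the dimensions, dominant direction `u`, and noise scale are illustrative assumptions, not values from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 64, 32  # parameter dimension, number of per-sample gradients

# Synthetic low-SNR regime: every gradient is (roughly) a scaled copy
# of one shared direction u, plus a small isotropic perturbation.
u = rng.normal(size=d)
u /= np.linalg.norm(u)
scales = rng.normal(1.0, 0.05, size=n)
G = np.outer(scales, u) + 0.001 * rng.normal(size=(n, d))

g_bar = G.mean(axis=0)  # mean gradient, the proposed rank-1 Fisher direction

# Dominant-singular-value share: close to 1.0 iff G is effectively rank-1.
s = np.linalg.svd(G, compute_uv=False)
rank1_share = s[0] / s.sum()

# Per-sample cosine similarity with the mean gradient.
cos = np.abs(G @ g_bar) / (np.linalg.norm(G, axis=1) * np.linalg.norm(g_bar))
```

Here `rank1_share` near 1 and mean cosine near 1 together indicate that the empirical Fisher `G.T @ G / n` is well approximated by the outer product of the mean gradient with itself.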

πŸ“ Abstract
Catastrophic forgetting remains a central obstacle for continual learning in neural models. Popular approaches -- replay and elastic weight consolidation (EWC) -- have limitations: replay requires a strong generator and is prone to distributional drift, while EWC implicitly assumes a shared optimum across tasks and typically uses a diagonal Fisher approximation. In this work, we study the gradient geometry of diffusion models, which can already produce high-quality replay data. We provide theoretical and empirical evidence that, in the low signal-to-noise ratio (SNR) regime, per-sample gradients become strongly collinear, yielding an empirical Fisher that is effectively rank-1 and aligned with the mean gradient. Leveraging this structure, we propose a rank-1 variant of EWC that is as cheap as the diagonal approximation yet captures the dominant curvature direction. We pair this penalty with a replay-based approach to encourage parameter sharing across tasks while mitigating drift. On class-incremental image generation datasets (MNIST, FashionMNIST, CIFAR-10, ImageNet-1k), our method consistently improves average FID and reduces forgetting relative to replay-only and diagonal-EWC baselines. In particular, forgetting is nearly eliminated on MNIST and FashionMNIST and is roughly halved on ImageNet-1k. These results suggest that diffusion models admit an approximately rank-1 Fisher. With a better Fisher estimate, EWC becomes a strong complement to replay: replay encourages parameter sharing across tasks, while EWC effectively constrains replay-induced drift.
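Because the Fisher is effectively the outer product of the mean gradient with itself, the EWC quadratic form collapses to a single scalar projection, matching the O(d) cost of the diagonal variant. A hedged sketch of both penalties (function names, `lam`, and the toy vectors are illustrative, not the paper's implementation):

```python
import numpy as np

def rank1_ewc_penalty(theta, theta_star, g_bar, lam=1.0):
    # With F ~= g_bar g_bar^T, the quadratic (theta - theta*)^T F (theta - theta*)
    # collapses to the squared projection onto the mean gradient: O(d) time/memory.
    proj = np.dot(theta - theta_star, g_bar)
    return 0.5 * lam * proj ** 2

def diag_ewc_penalty(theta, theta_star, f_diag, lam=1.0):
    # Standard diagonal-Fisher EWC penalty, also O(d), but it ignores the
    # correlation structure across coordinates that the rank-1 form keeps.
    return 0.5 * lam * np.sum(f_diag * (theta - theta_star) ** 2)

# Toy example: drift of 1 along the first coordinate, mean gradient [2, 0, 0].
theta_star = np.zeros(3)
theta = np.array([1.0, 0.0, 0.0])
g_bar = np.array([2.0, 0.0, 0.0])
r1 = rank1_ewc_penalty(theta, theta_star, g_bar)  # 0.5 * (1 * 2)^2 = 2.0
```

When the drift lies exactly along `g_bar` the two penalties agree (taking `f_diag = g_bar**2`); they diverge once the drift has components the per-coordinate diagonal cannot relate to the dominant curvature direction.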
Problem

Research questions and friction points this paper is trying to address.

Mitigate catastrophic forgetting in continual learning for neural generative models
Improve the Fisher approximation in EWC to better reflect the models' gradient geometry
Combine generative replay with an enhanced EWC penalty to reduce distributional drift
Innovation

Methods, ideas, or system contributions that make the work stand out.

Rank-1 Fisher approximation captures the dominant curvature direction at diagonal-EWC cost
Pairs a replay-based approach with the rank-1 EWC penalty
Mitigates catastrophic forgetting in diffusion model training
Zekun Wang
College of Computing, Georgia Institute of Technology, Atlanta, GA 30332, USA
Anant Gupta
College of Computing, Georgia Institute of Technology, Atlanta, GA 30332, USA
Zihan Dong
College of Computing, Georgia Institute of Technology, Atlanta, GA 30332, USA
Christopher J. MacLellan
Assistant Professor, Georgia Institute of Technology
Cognitive Systems · Artificial Intelligence in Education · Human-AI Teaming · Concept Formation