🤖 AI Summary
Existing dataset distillation methods typically link the original and synthetic datasets through a fixed architecture (usually a ConvNet). This hurts accuracy when a different architecture is used downstream, and these methods also struggle to generate high-resolution images. This paper proposes Latent Dataset Distillation with Diffusion Models (LD3M), which performs dataset distillation in the latent space of a pre-trained diffusion model. Its diffusion process is tailored to distillation and improves the gradient flow. The number of diffusion steps can also be adjusted to trade distillation speed against dataset quality. LD3M supports high-resolution synthesis (128×128 and 256×256) and, on several ImageNet subsets, outperforms state-of-the-art methods by up to 4.8 and 4.2 percentage points for 1 and 10 images per class, respectively.
📝 Abstract
Machine learning traditionally relies on ever-larger datasets. Yet such datasets pose major storage challenges and usually contain non-influential samples, which could be ignored during training without negatively impacting the training quality. In response, the idea of distilling a dataset into a condensed set of synthetic samples, i.e., a distilled dataset, emerged. One key aspect is the architecture selected for linking the original and synthetic datasets, usually a ConvNet. However, the final accuracy is lower if the model architecture employed downstream differs from the one used during distillation. Another challenge is the generation of high-resolution images (128×128 and higher). To address both challenges, this paper proposes Latent Dataset Distillation with Diffusion Models (LD3M), which combines diffusion in latent space with dataset distillation. Our novel diffusion process is tailored for this task and significantly improves the gradient flow for distillation. By adjusting the number of diffusion steps, LD3M also offers a convenient way of controlling the trade-off between distillation speed and dataset quality. Overall, LD3M consistently outperforms state-of-the-art methods on several ImageNet subsets and at high resolutions (128×128 and 256×256), by up to 4.8 p.p. and 4.2 p.p. for 1 and 10 images per class, respectively.
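The abstract attributes part of LD3M's gains to better gradient flow through the diffusion process. It also notes that the number of diffusion steps trades distillation speed against quality. The toy NumPy sketch below shows why both matter when a synthetic sample is optimized by backpropagating through a chain of denoising steps. It is not the LD3M implementation. The linear "denoiser" `A`, the frozen linear "decoder" `D`, the single target vector `y`, and the shortcut weight `alpha` (which mixes the initial latent back in at every step) are all illustrative assumptions.

```python
import numpy as np


def denoise_chain(z_init, A, steps, alpha):
    """Toy linear 'denoiser' (assumption): z_k = (1 - alpha) * A @ z_{k-1} + alpha * z_init.

    Returns the final latent z_T and the Jacobian dz_T/dz_init, which is what
    backpropagation multiplies the upstream gradient by.
    """
    d = z_init.shape[0]
    z, J = z_init.copy(), np.eye(d)
    for _ in range(steps):
        z = (1 - alpha) * (A @ z) + alpha * z_init
        J = (1 - alpha) * (A @ J) + alpha * np.eye(d)
    return z, J


def distill(alpha, steps, iters=200, lr=0.1, seed=0):
    """Optimize one learnable latent so its decoded 'image' matches a target.

    Each iteration runs `steps` denoising updates, so cost grows linearly with it.
    Returns the per-iteration losses and gradient norms.
    """
    rng = np.random.default_rng(seed)
    d, p = 4, 8                          # toy latent and "image" sizes
    A = 0.5 * np.eye(d)                  # contractive step, so gradients shrink per step
    D = rng.normal(size=(p, d))          # hypothetical frozen linear decoder
    y = rng.normal(size=p)               # hypothetical target statistic from real data
    z_init = rng.normal(size=d)          # the learnable synthetic sample (latent)
    losses, grad_norms = [], []
    for _ in range(iters):
        z_T, J = denoise_chain(z_init, A, steps, alpha)
        residual = D @ z_T - y
        losses.append(0.5 * float(residual @ residual))
        grad = J.T @ (D.T @ residual)    # chain rule back through all denoising steps
        grad_norms.append(float(np.linalg.norm(grad)))
        z_init -= lr * grad              # plain gradient descent on the latent
    return losses, grad_norms


if __name__ == "__main__":
    for alpha in (0.0, 0.1):
        losses, grads = distill(alpha=alpha, steps=20)
        print(f"alpha={alpha}: first grad norm={grads[0]:.3e}, "
              f"loss {losses[0]:.4f} -> {losses[-1]:.4f}")
```

With `alpha=0.0`, the Jacobian after 20 steps is 0.5^20 times the identity, so the gradient reaching the latent is about a million times weaker than the upstream signal and the loss barely moves. A nonzero `alpha` keeps the Jacobian near alpha / (1 - 0.5 * (1 - alpha)) times the identity (about 0.18 for `alpha=0.1`). Each iteration also costs `steps` matrix products, which mirrors the speed side of the trade-off.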