🤖 AI Summary
Addressing the challenge of balancing differential privacy (DP) guarantees and computational efficiency in large language model (LLM) training, this paper introduces FlashDP, the first framework to fully fuse per-layer DP-SGD operations (gradient clipping, noise injection, and aggregation) into a single forward/backward pass. Its core innovation lies in eliminating explicit per-sample gradient storage: it instead performs layer-wise fused computation with cache-friendly memory access patterns and fine-grained memory management, completing all DP operations within one gradient computation. Evaluated on four A100 GPUs pretraining Llama-13B under standard DP-SGD privacy budgets, FlashDP matches the accuracy of baseline DP-SGD while achieving 90% of the throughput of non-private training, reducing memory movement by 50% and redundant computation by 20%. The implementation is publicly available.
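To make the per-layer DP-SGD operations concrete, here is a minimal NumPy sketch of the three steps FlashDP fuses: per-sample clipping, aggregation, and noise injection. This is the standard (unfused) reference computation, not FlashDP's fused GPU kernel; the function name and the small-epsilon guard are illustrative choices, not part of the paper.

```python
import numpy as np

def per_layer_dp_grad(per_sample_grads, clip_norm, noise_multiplier, rng):
    """Reference per-layer DP-SGD step for one layer.

    per_sample_grads: array of shape (batch, *param_shape), one gradient
    per training example. Each is clipped to L2 norm <= clip_norm, the
    clipped gradients are summed, and Gaussian noise with standard
    deviation noise_multiplier * clip_norm is added to the sum.
    """
    clipped = []
    for g in per_sample_grads:
        norm = np.linalg.norm(g)
        # Scale down gradients whose norm exceeds the clipping threshold.
        clipped.append(g * min(1.0, clip_norm / (norm + 1e-12)))
    summed = np.sum(clipped, axis=0)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=summed.shape)
    return summed + noise
```

In this reference form the `per_sample_grads` tensor must be materialized in full, which is exactly the memory cost (batch size times parameter count, per layer) that explicit methods like Opacus pay and that FlashDP's fused computation avoids.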
📄 Abstract
As large language models (LLMs) increasingly underpin technological advancements, the privacy of their training data emerges as a critical concern. Differential Privacy (DP) serves as a rigorous mechanism to protect this data, yet its integration via Differentially Private Stochastic Gradient Descent (DP-SGD) introduces substantial challenges, primarily due to the complexities of per-sample gradient clipping. Current explicit methods, such as Opacus, necessitate extensive storage for per-sample gradients, significantly inflating memory requirements. Conversely, implicit methods like GhostClip reduce storage needs by recalculating gradients multiple times, which leads to inefficiencies from redundant computation. This paper introduces FlashDP, an innovative cache-friendly per-layer DP-SGD that consolidates the necessary operations into a single task, calculating gradients only once in a fused manner. This approach not only diminishes memory movement by up to **50%** but also cuts redundant computation by **20%** compared to previous methods. Consequently, FlashDP does not increase memory demands and achieves **90%** of the throughput of the non-DP method on a four-A100 system during pre-training of the Llama-13B model, while maintaining parity with standard per-layer clipped DP-SGD in terms of accuracy. These advancements establish FlashDP as a pivotal development for efficient and privacy-preserving training of LLMs. FlashDP's code has been open-sourced at https://github.com/kaustpradalab/flashdp.
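The abstract contrasts explicit methods, which store every per-sample gradient, with approaches that avoid that storage. The sketch below illustrates the storage-avoiding idea for a single linear layer: the per-sample gradient of `y = x @ W.T` is the outer product of the backpropagated error and the input activation, so clipped gradients can be accumulated one sample at a time without holding the full `(batch, out, in)` tensor. This is only a conceptual, loop-based illustration of the memory trade-off, not FlashDP's fused kernel or Opacus's actual API.

```python
import numpy as np

def streaming_clipped_sum(activations, backprops, clip_norm):
    """Accumulate clipped per-sample weight gradients for a linear layer.

    activations: (batch, in_dim) layer inputs x_i.
    backprops:   (batch, out_dim) gradients of the loss w.r.t. the layer
                 outputs. The per-sample weight gradient is outer(b_i, x_i).
    Only one per-sample gradient exists at a time, so peak extra memory is
    O(out_dim * in_dim) instead of O(batch * out_dim * in_dim).
    """
    out_dim = backprops.shape[1]
    in_dim = activations.shape[1]
    acc = np.zeros((out_dim, in_dim))
    for x_i, b_i in zip(activations, backprops):
        g_i = np.outer(b_i, x_i)                      # per-sample gradient
        norm = np.linalg.norm(g_i)
        acc += g_i * min(1.0, clip_norm / (norm + 1e-12))  # clip, then sum
    return acc
```

A fused implementation would additionally add the calibrated Gaussian noise in the same pass and run these steps inside one GPU kernel with cache-friendly access patterns, which is where FlashDP's reported memory-movement savings come from.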