🤖 AI Summary
This work addresses the scalability limitations of existing fully homomorphic encryption (FHE)-based approaches for privacy-preserving large language model inference, particularly their inefficiency with long sequences and their excessive computational overhead in nonlinear layers caused by outlier values. To overcome these challenges, the authors propose an unbalanced chunked-prefill framework that encrypts only the sensitive portion of the input—specifically the last 128 tokens—under the CKKS scheme, while leveraging a hybrid plaintext-ciphertext computation paradigm. The approach integrates a novel homomorphic matrix multiplication algorithm, an efficient polynomial evaluation method, and training-free mitigation strategies (token prepending and rotations) to suppress outliers. Evaluated on a cluster of 8 NVIDIA RTX 4090 GPUs, the system achieves the first end-to-end private inference for Llama-2-7B on 4096-token inputs, requiring 85 seconds for summarization and 33 seconds per output token for generation.
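To make the unbalanced chunked-prefill idea concrete, here is a minimal, purely illustrative Python sketch of the input split it describes: the public prefix is prefilled in plaintext chunks, while only the trailing 128 tokens would be CKKS-encrypted. All names (`split_prompt`, `chunk_size`) are hypothetical, not the paper's API.

```python
PRIVATE_LEN = 128  # trailing tokens treated as sensitive (per the paper)

def split_prompt(tokens, chunk_size=512, private_len=PRIVATE_LEN):
    """Partition input tokens into plaintext chunks plus one encrypted chunk."""
    public, private = tokens[:-private_len], tokens[-private_len:]
    # The plaintext prefix is prefilled chunk by chunk (plaintext-plaintext matmuls).
    pt_chunks = [public[i:i + chunk_size] for i in range(0, len(public), chunk_size)]
    # The private suffix attends both to the plaintext KV cache (plaintext-ciphertext)
    # and to itself (ciphertext-ciphertext): the three computational components.
    return pt_chunks, private

tokens = list(range(4096))           # stand-in for 4096 input token ids
pt_chunks, ct_chunk = split_prompt(tokens)
assert len(ct_chunk) == 128
assert sum(len(c) for c in pt_chunks) == 4096 - 128
```

The point of the split is that the expensive homomorphic work scales with the 128 encrypted tokens rather than with the full 4096-token context.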
📝 Abstract
As large language models (LLMs) become ubiquitous, privacy concerns pertaining to inference inputs keep growing. In this context, fully homomorphic encryption (FHE) has emerged as a primary cryptographic solution for non-interactive confidential LLM inference. Existing solutions scale poorly with the input token length, and hence focus either on small models or on larger models with a small number of input tokens. They also suffer from the existence of large outlier values. These values strongly impact the evaluation of non-linear layers, forcing large-degree polynomial approximations and thus heavy evaluation costs. We propose an FHE-based private LLM inference solution that allows thousands of input tokens with only a part of them being encrypted: this fits a scenario where the context is benign and only part of the input is sensitive. To do so, we suggest an unbalanced chunked-prefill framework that processes the private and public parts of the input tokens differently. Our framework contains plaintext-plaintext, plaintext-ciphertext and ciphertext-ciphertext computational components, and we adopt different strategies and ingredients for each. We also devise new homomorphic algorithms for specific matrix multiplication and polynomial evaluation tasks encountered during LLM inference. Furthermore, without retraining, we tailor the LLM inference algorithm to reduce the ranges of outlier values: we leverage machine learning strategies (token prepending and rotations) to mitigate the impact of the outliers on non-linear layers. Based on these ingredients, we describe a CKKS-based end-to-end implementation of Llama-2-7B private inference for up to 4096 input tokens, of which the last 128 are encrypted. On a cluster of 8 NVIDIA RTX 4090 GPUs, inference takes 85 s for summarization and 33 s per output token for generation.
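The abstract's point about outliers and polynomial degree can be illustrated with a toy experiment. CKKS natively evaluates only polynomials, so non-linear functions (e.g., the exponential inside softmax) must be approximated by one; the wider the input range the outliers force, the worse a fixed-degree approximation becomes. The sketch below (pure Python, Taylor approximation of `exp` as a stand-in for the paper's more sophisticated polynomial evaluation) shows the error blowing up as the evaluation radius grows; it is an illustration of the phenomenon, not the paper's method.

```python
import math

def horner(coeffs, x):
    """Evaluate a polynomial (coefficients highest-degree first) via Horner's
    rule -- the only kind of function CKKS can compute natively."""
    acc = 0.0
    for c in coeffs:
        acc = acc * x + c
    return acc

def taylor_exp_coeffs(degree):
    # exp(x) ~ sum_k x^k / k!, listed highest-degree first for horner().
    return [1.0 / math.factorial(k) for k in range(degree, -1, -1)]

def max_err(degree, radius, samples=400):
    """Worst-case approximation error of the degree-`degree` Taylor polynomial
    for exp on the interval [-radius, radius]."""
    cs = taylor_exp_coeffs(degree)
    return max(abs(horner(cs, x) - math.exp(x))
               for x in (-radius + 2 * radius * i / samples
                         for i in range(samples + 1)))

# On a narrow range a degree-8 polynomial suffices; once outliers widen the
# range, the same degree fails badly -- the cost that the paper's
# range-reduction strategies (token prepending, rotations) are meant to avoid.
assert max_err(8, 1.0) < 1e-3
assert max_err(8, 8.0) > 1.0
```

This is why suppressing outliers translates directly into lower-degree approximations and cheaper homomorphic evaluation of the non-linear layers.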