🤖 AI Summary
To address the challenge of simultaneously achieving high accuracy, low latency, and strong privacy in large language model (LLM) inference services, this paper proposes HAT, a novel hierarchical architecture for split inference. HAT introduces a "hat-shaped" model partitioning scheme that decouples the LLM into a lightweight on-device small language model (SLM), comprising the input and output submodels, and a cloud-resident backbone decoder, and it enables efficient inference via hidden-state-level collaboration and speculative decoding. The framework further incorporates a hidden-state-exchange-based device-cloud coordination mechanism and a dynamic chunked scheduling strategy for long prompts, ensuring end-to-end data privacy (i.e., no raw user data leaves the device) while optimizing computational and communication efficiency. Experiments on a testbed comprising 30 Jetson edge devices and an 8×A6000 GPU server demonstrate that HAT reduces time-to-first-token (TTFT) by 41% to 54% and time-between-tokens (TBT) by 41% to 77%, significantly outperforming baseline split-inference approaches.
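As a rough illustration of the partitioning described above, the following sketch keeps the first and last decoder layers plus the LM head on the device and places the bulk of the decoder in the cloud, so only hidden states, never raw tokens, cross the device-cloud boundary. The module names, layer counts, and dimensions are illustrative assumptions, not the paper's actual architecture, and the adapter network and the speculative-decoding verification loop are omitted for brevity.

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the paper's actual split and dimensions may differ.
D_MODEL, N_HEADS, VOCAB = 256, 4, 1000

def decoder_stack(num_layers):
    # Stand-in for a stack of transformer decoder layers.
    return nn.Sequential(*[
        nn.TransformerEncoderLayer(D_MODEL, N_HEADS, batch_first=True)
        for _ in range(num_layers)
    ])

class DeviceSLM(nn.Module):
    """On-device part: embedding plus the first layers (input submodel) and
    the last layers plus the LM head (output submodel)."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, D_MODEL)
        self.input_layers = decoder_stack(2)    # bottom "brim" of the hat
        self.output_layers = decoder_stack(2)   # top "brim" of the hat
        self.lm_head = nn.Linear(D_MODEL, VOCAB)

    def encode(self, token_ids):
        # Raw tokens never leave the device; only these hidden states are uploaded.
        return self.input_layers(self.embed(token_ids))

    def decode(self, hidden):
        # Hidden states returned by the cloud become next-token logits locally.
        return self.lm_head(self.output_layers(hidden))

class CloudBackbone(nn.Module):
    """Cloud part: the majority of the decoder layers, operating only on hidden states."""
    def __init__(self):
        super().__init__()
        self.middle_layers = decoder_stack(8)

    def forward(self, hidden):
        return self.middle_layers(hidden)

# One collaborative forward pass over a toy prompt.
device_slm, cloud = DeviceSLM(), CloudBackbone()
prompt_ids = torch.randint(0, VOCAB, (1, 16))
h_up = device_slm.encode(prompt_ids)   # device -> cloud: hidden states only
h_down = cloud(h_up)                   # heavy computation stays in the cloud
logits = device_slm.decode(h_down)     # device produces the next-token distribution
print(logits.shape)                    # torch.Size([1, 16, 1000])
```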
📝 Abstract
Recent advancements in large language models (LLMs) have catalyzed a substantial surge in demand for LLM services. While traditional cloud-based LLM services satisfy high-accuracy requirements, they fall short of meeting critical demands for low delay and enhanced privacy. To address these limitations, we propose HAT, a novel device-cloud collaborative inference framework that leverages the complementary strengths of U-shaped inference and speculative decoding. HAT partitions the LLM into three submodels: the input and output submodels, stacked with a lightweight adapter network, are deployed as a small language model (SLM) on each end device, while the middle submodel, encompassing the majority of the LLM's decoder layers, is hosted in the cloud to perform speculative decoding with the on-device SLMs. During inference, HAT exchanges the hidden states (rather than the raw tokens) of input or draft tokens between devices and the cloud; since hidden states are far larger than token IDs, this exchange incurs substantial communication delays. Moreover, processing the hidden states of long prompts exacerbates computation delays in the cloud, further compromising inference efficiency. To improve efficiency, we introduce a prompt chunking mechanism that segments long prompts into shorter chunks, enabling parallel transmission and processing. Furthermore, HAT dynamically determines the optimal chunk size for each device handling a long prompt, thereby improving overall inference speed. Extensive experiments are conducted on a physical testbed comprising 30 NVIDIA Jetson devices and a server with 8 NVIDIA A6000 GPUs. Experimental results demonstrate that HAT achieves promising performance improvements, reducing time-to-first-token (TTFT) by 41% to 54% and time-between-tokens (TBT) by 41% to 77% compared to the baselines.
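The chunking and dynamic chunk-size selection described in the abstract can be sketched with a simple pipeline latency model: uploading chunk i overlaps with cloud computation on chunk i-1, and the chunk size is chosen to minimize the estimated time-to-first-token. The function names, the two-stage latency model, and all numeric parameters below are assumptions for illustration, not HAT's actual scheduler.

```python
from math import ceil

def chunk_prompt(hidden_states, chunk_size):
    """Split the per-token hidden states of a long prompt into consecutive chunks."""
    return [hidden_states[i:i + chunk_size]
            for i in range(0, len(hidden_states), chunk_size)]

def estimate_ttft(prompt_len, chunk, bytes_per_token, bandwidth_Bps,
                  cloud_tokens_per_sec, per_chunk_overhead=0.01):
    """Rough two-stage pipeline model: uploading chunk i overlaps with cloud
    computation on chunk i-1 (an assumed latency model, not HAT's)."""
    n_chunks = ceil(prompt_len / chunk)
    t_send = per_chunk_overhead + chunk * bytes_per_token / bandwidth_Bps
    t_compute = chunk / cloud_tokens_per_sec
    # The first chunk must arrive before compute starts; after that the slower
    # of the two stages sets the pace of the pipeline.
    return t_send + n_chunks * max(t_send, t_compute)

def pick_chunk_size(prompt_len, *latency_args, candidates=(64, 128, 256, 512, 1024)):
    """Pick the candidate chunk size with the lowest estimated TTFT."""
    return min(candidates, key=lambda c: estimate_ttft(prompt_len, c, *latency_args))

# Example: 4096-token prompt, ~8 KB of hidden state per token, a 10 MB/s uplink,
# a cloud backbone prefilling ~2000 tokens/s, and ~10 ms of fixed overhead per chunk.
best = pick_chunk_size(4096, 8e3, 10e6, 2000)
print("chunk size:", best, "->", ceil(4096 / best), "chunks")
```

Under these assumed numbers the heuristic prefers a mid-sized chunk: very small chunks pay the fixed per-chunk overhead many times, while very large chunks delay the start of cloud computation on the first chunk.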