🤖 AI Summary
This work addresses the high computational cost of iterative denoising in action diffusion models. Existing feature-caching accelerators adapt poorly to the policy dynamics induced by perceptual shifts and multi-round rollouts in open environments. To overcome this limitation, the authors propose test-time sparsity, which decouples condition encoding and pruning from the autoregressive denoising loop and uses a lightweight pruner, sharing the diffusion transformer's encoder, to dynamically predict which residual computations can be pruned. An asynchronous parallel inference pipeline is paired with an omnidirectional feature reuse mechanism that draws on features from the current forward pass, previous denoising timesteps, and earlier rollout iterations. The reuse strategy is learned by step-by-step supervision from a small number of sampled action trajectories. The method reduces FLOPs by 92% and accelerates action generation by 5×, reaching an inference frequency of 47.5 Hz while preserving task performance.
📝 Abstract
Action diffusion excels at high-fidelity action generation but incurs heavy computational costs owing to its iterative denoising. Although existing methods show promise in accelerating diffusion transformers by reusing cached features, they struggle to adapt to the policy dynamics that arise from diverse perceptions and multi-round rollout iterations in open environments. To tackle this challenge, we propose test-time sparsity, which accelerates action diffusion by dynamically predicting prunable residual computations for each forward pass of the model at test time. However, two bottlenecks remain in this paradigm: 1) repetitive conditional encoding and pruning offset most of the potential speed gains, and 2) features cached from previous denoising timesteps cannot bound the large pruning errors that occur under aggressive sparsity. To address the first bottleneck, we design a highly parallelized inference pipeline that reduces the non-decoder delay to milliseconds. Specifically, we first design a lightweight pruner that shares the encoder with the diffusion transformer. We then decouple encoding and pruning from the autoregressive denoising loop by processing all denoising timesteps in parallel, and run the pruner asynchronously so that it overlaps with the decoder's forward inference. To overcome the second bottleneck, we introduce an omnidirectional reusing strategy, which achieves 95% sparsity by selectively reusing features cached from the current forward pass, previous denoising timesteps, and earlier rollout iterations. To learn the rollout-level reusing strategies, we sample a few action trajectories and use them to supervise the sparsified diffusion step by step. Extensive experiments demonstrate that our method reduces FLOPs by 92% and accelerates action generation by 5×, achieving lossless performance at an inference frequency of 47.5 Hz. Our code is available at https://github.com/ky-ji/Test-time-Sparsity.
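To make the omnidirectional reusing idea concrete, here is a toy sketch in plain Python. It is not the authors' implementation. The decision codes, `block_residual`, `run_rollout`, and the scalar "weights" are all hypothetical names chosen for illustration, and the reading of "current forward" as "the previous block in the same forward pass" is an assumption. The pruning plan is passed in as a precomputed table, mirroring the idea that pruning decisions are made outside the denoising loop; in the paper that plan would presumably come from the learned pruner.

```python
import math

# Hypothetical decision codes for one residual block at one denoising step.
COMPUTE, REUSE_LAYER, REUSE_STEP, REUSE_ROLLOUT = range(4)

def block_residual(x, w):
    """Toy stand-in for an expensive residual branch f(x) of one block."""
    return [math.tanh(w * v) for v in x]

def run_rollout(x, weights, plan, prev_cache=None):
    """Apply x <- x + f(x) per block, skipping f where the plan says reuse.

    plan[t][b] is decided for every step t and block b before the loop runs.
    If the requested cache entry does not exist yet, the block is computed.
    """
    num_steps, num_blocks = len(plan), len(weights)
    cache = [[None] * num_blocks for _ in range(num_steps)]
    computed = 0
    for t in range(num_steps):
        for b in range(num_blocks):
            choice = plan[t][b]
            if choice == REUSE_LAYER and b > 0:
                r = cache[t][b - 1]       # same forward pass, previous block
            elif choice == REUSE_STEP and t > 0:
                r = cache[t - 1][b]       # previous denoising timestep
            elif choice == REUSE_ROLLOUT and prev_cache is not None:
                r = prev_cache[t][b]      # earlier rollout iteration
            else:
                r = block_residual(x, weights[b])
                computed += 1
            cache[t][b] = r
            x = [xi + ri for xi, ri in zip(x, r)]
    return x, cache, computed

weights = [0.3, -0.2, 0.5]                # 3 toy blocks
x0 = [0.1, -0.4, 0.7]                     # toy action features
steps = 5

dense_plan = [[COMPUTE] * 3 for _ in range(steps)]
out_dense, cache, n_dense = run_rollout(x0, weights, dense_plan)

# Second rollout with an identical input: every residual is reused.
sparse_plan = [[REUSE_ROLLOUT] * 3 for _ in range(steps)]
out_sparse, _, n_sparse = run_rollout(x0, weights, sparse_plan, cache)
print(n_dense, n_sparse)  # 15 0
```

In this toy the second rollout reproduces the dense output exactly only because its input is unchanged. With real, shifting observations, reused features introduce error, which is why the abstract describes a learned pruner that chooses per block between computing and each of the three cache sources.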