AI Summary
This work addresses the high computational and memory overhead of attention mechanisms, which severely limits Transformer inference efficiency on edge devices. To this end, we propose the Function-preserving Attention Replacement (FAR) framework, the first systematic study demonstrating that attention can be replaced at inference time by lightweight sequence-mapping modules without compromising the semantic relationships among tokens. Methodologically, FAR employs a multi-head LSTM architecture, jointly optimized via block-level response distillation and global structured pruning, and supports seamless adaptation of pre-trained Transformers. Evaluated on DeiT models, FAR matches the original accuracy on ImageNet and downstream tasks while reducing parameters by 32% and end-to-end latency by 27%, all while preserving the critical token-level dependencies modeled by attention. This work establishes a new paradigm for efficient Transformer deployment on resource-constrained edge platforms.
Abstract
While transformers excel across vision and language pretraining tasks, their reliance on attention mechanisms poses challenges for inference efficiency, especially on edge and embedded accelerators with limited parallelism and memory bandwidth. Motivated by the observed redundancy of attention at inference time, we hypothesize that although the model learns complex token dependencies through pretraining, the inference-time sequence-to-sequence mapping in each attention layer is "simple" enough to be represented by a much cheaper function. In this work, we explore FAR, a Function-preserving Attention Replacement framework that replaces all attention blocks in pretrained transformers with learnable sequence-to-sequence modules, exemplified by an LSTM. FAR optimizes a multi-head LSTM architecture with a block-wise distillation objective and a global structural pruning framework, yielding a family of efficient LSTM-based models from pretrained transformers. We validate FAR on the DeiT vision transformer family and demonstrate that it matches the accuracy of the original models on ImageNet and multiple downstream tasks with fewer parameters and lower latency. Further analysis shows that FAR preserves the semantic token relationships and the token-to-token correlations learned by the transformer's attention modules.
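To make the two ingredients concrete, below is a minimal, untrained pure-Python sketch of (a) a multi-head LSTM used as a drop-in sequence-to-sequence replacement for an attention block, and (b) a block-wise response-distillation loss that matches the replacement's outputs to the frozen attention block's outputs. All class names, layer sizes, and random weights here are illustrative assumptions, not the paper's actual architecture or training recipe.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

class TinyLSTMHead:
    """One LSTM 'head' operating on a slice of the token embedding.
    Weights are random: an untrained illustration, not the trained
    replacement module from the paper (hypothetical sizes)."""

    def __init__(self, dim, rng):
        self.dim = dim
        # Four gates (input, forget, output, candidate); each output
        # unit has weights over the concatenated [x, h, 1] vector.
        self.W = [[[rng.uniform(-0.1, 0.1) for _ in range(2 * dim + 1)]
                   for _ in range(dim)] for _ in range(4)]

    def step(self, x, h, c):
        xh = x + h + [1.0]  # input, previous hidden state, bias term
        pre = [[sum(w * v for w, v in zip(row, xh)) for row in gate]
               for gate in self.W]
        i = [sigmoid(z) for z in pre[0]]
        f = [sigmoid(z) for z in pre[1]]
        o = [sigmoid(z) for z in pre[2]]
        g = [math.tanh(z) for z in pre[3]]
        c_new = [fv * cv + iv * gv for fv, cv, iv, gv in zip(f, c, i, g)]
        h_new = [ov * math.tanh(cv) for ov, cv in zip(o, c_new)]
        return h_new, c_new

    def run(self, seq):
        h, c = [0.0] * self.dim, [0.0] * self.dim
        outputs = []
        for x in seq:
            h, c = self.step(x, h, c)
            outputs.append(h)
        return outputs

def multi_head_lstm(tokens, heads):
    """Split each token embedding across heads, run each head's LSTM
    over the sequence, and concatenate per-head outputs: a
    sequence-to-sequence map with attention-shaped output."""
    d = heads[0].dim
    head_outs = [head.run([tok[k * d:(k + 1) * d] for tok in tokens])
                 for k, head in enumerate(heads)]
    return [sum((ho[t] for ho in head_outs), []) for t in range(len(tokens))]

def block_distill_loss(student_out, teacher_out):
    """Block-wise response distillation: mean squared error between the
    replacement module's outputs and the frozen attention block's."""
    n = sum(len(s) for s in student_out)
    return sum((sv - tv) ** 2
               for s, t in zip(student_out, teacher_out)
               for sv, tv in zip(s, t)) / n

rng = random.Random(0)
embed_dim, n_heads, seq_len = 8, 2, 4
heads = [TinyLSTMHead(embed_dim // n_heads, rng) for _ in range(n_heads)]
tokens = [[rng.uniform(-1.0, 1.0) for _ in range(embed_dim)]
          for _ in range(seq_len)]
out = multi_head_lstm(tokens, heads)
loss = block_distill_loss(out, tokens)  # stand-in "teacher" targets
```

In the paper's setting the teacher targets would be the pretrained attention block's responses rather than the raw tokens used here, and each replacement module would be trained to minimize this per-block loss before global structured pruning.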