🤖 AI Summary
Modern deep learning compilers lack a unified abstraction for jointly expressing tensor layouts across device distributions and intra-device memory hierarchies, hindering effective optimization. This work proposes the Axe layout abstraction, which maps logical tensor coordinates to a multi-axis physical space through named axes, thereby unifying support for tiling, sharding, replication, and offsetting. Building upon this abstraction, the authors design a multi-granularity, distribution-aware domain-specific language (DSL) and a hardware-aware compiler that integrates fine-grained thread-level control with collective communication primitives. Their approach achieves, for the first time, a unified layout representation spanning from device grids down to the thread level, yielding generated code that is portable and performs close to hand-tuned implementations.
📝 Abstract
Scaling modern deep learning workloads demands coordinated placement of data and compute across device meshes, memory hierarchies, and heterogeneous accelerators. We present Axe Layout, a hardware-aware abstraction that maps logical tensor coordinates to a multi-axis physical space via named axes. Axe unifies tiling, sharding, replication, and offsets across inter-device distribution and on-device layouts, enabling collective primitives to be expressed consistently from device meshes to threads. Building on Axe, we design a multi-granularity, distribution-aware DSL and compiler that composes thread-local control with collective operators in a single kernel. Experiments show that our unified approach brings performance close to that of hand-tuned kernels across the latest GPUs, multi-device environments, and accelerator backends.
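To make the core idea concrete, the sketch below illustrates what a named-axis layout might look like: a logical tensor index is decomposed into coordinates on a hierarchy of named physical axes, ordered from the device mesh down to per-thread elements. All class names, axis names, and the mixed-radix decomposition scheme here are illustrative assumptions, not the paper's actual DSL or API.

```python
# Hypothetical sketch of a named-axis layout in the spirit of Axe.
# Each Axis names one level of the physical hierarchy (device mesh,
# thread block, thread, per-thread element); the layout maps a 1-D
# logical index to a coordinate on every axis.
from dataclasses import dataclass

@dataclass
class Axis:
    name: str   # e.g. "device", "block", "thread", "elem"
    size: int   # extent of this physical axis

class AxeLayout:
    """Maps a 1-D logical index to coordinates on named physical axes.

    Axes are ordered outermost-first (device mesh) to innermost-last
    (per-thread element), so the mapping is a mixed-radix decomposition
    of the logical index. Tiling and sharding correspond to how the
    logical range is split across these axes; replication would map
    the same logical index to several coordinates on one axis.
    """
    def __init__(self, axes):
        self.axes = axes

    def logical_to_physical(self, idx):
        coords = {}
        # Peel off mixed-radix digits, innermost axis first.
        for ax in reversed(self.axes):
            coords[ax.name] = idx % ax.size
            idx //= ax.size
        return coords

# Example: 2 devices x 4 blocks x 8 threads x 2 elements = 128 logical elements.
layout = AxeLayout([Axis("device", 2), Axis("block", 4),
                    Axis("thread", 8), Axis("elem", 2)])
print(layout.logical_to_physical(37))
# 37 = ((0*4 + 2)*8 + 2)*2 + 1 -> device 0, block 2, thread 2, elem 1
```

The point of the unified representation is that the same mapping vocabulary covers both the inter-device axes (where a coordinate selects a shard on the device mesh) and the intra-device axes (where it selects a block, thread, or register element), so a compiler can reason about data movement at every level with one abstraction.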