Sparton: Fast and Memory-Efficient Triton Kernel for Learned Sparse Retrieval

📅 2026-03-26
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
This work addresses the memory and computational bottlenecks in Learned Sparse Retrieval (LSR) models caused by the language modeling head generating large-scale vocabulary logit matrices. To overcome this, the authors propose Sparton, an efficient Triton-based GPU kernel that, for the first time, fuses matrix multiplication, ReLU, Log1P, and max-reduction into a single kernel. This design enables an early online reduction without explicitly materializing the full logit matrix, substantially reducing memory consumption and accelerating training. Evaluated on Splade, Sparton achieves a 4.8× speedup and a 10× reduction in peak memory usage, enabling 33% larger batches with small vocabularies and up to 26× larger batches with large vocabularies, while delivering up to 2.5× faster training, all without any loss in model effectiveness.
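The computation being fused can be sketched in plain NumPy. This is the baseline (unfused) LSR head as the summary describes it: the full logit matrix is materialized before the element-wise transforms and the max-pool. Shapes and names here are illustrative assumptions, not taken from the paper.

```python
import numpy as np

# Illustrative shapes (assumptions): sequence length L, hidden dim D,
# vocabulary size V. With V in the tens or hundreds of thousands, the
# (L x V) logits array below dominates memory.
L, D, V = 128, 768, 30_000
rng = np.random.default_rng(0)
H = rng.standard_normal((L, D)).astype(np.float32)  # token hidden states
W = rng.standard_normal((D, V)).astype(np.float32)  # LM-head weight

logits = H @ W                           # (L, V) -- the memory bottleneck
rep = np.log1p(np.maximum(logits, 0.0))  # element-wise ReLU, then Log1P
rep = rep.max(axis=0)                    # max-pool over the sequence -> (V,)
```

Each of the four steps here is a separate operator in a standard framework, which is where the intermediate materialization and inter-operator I/O overhead come from.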

πŸ“ Abstract
State-of-the-art Learned Sparse Retrieval (LSR) models, such as Splade, typically employ a Language Modeling (LM) head to project latent hidden states into a lexically anchored logit matrix. This intermediate matrix is subsequently transformed into a sparse lexical representation through element-wise operations (ReLU, Log1P) and max-pooling over the sequence dimension. Despite its effectiveness, this design suffers from the sheer size of the vocabulary (V), which ranges from 30,000 to over 250,000 tokens in recent models: materializing the full logit matrix creates a significant memory bottleneck that limits model scaling, and the resulting I/O overhead between operators further throttles throughput and runtime performance. In this paper, we propose Sparton, a fast, memory-efficient Triton kernel tailored to the LM head in LSR models. Sparton fuses tiled matrix multiplication, ReLU, Log1P, and max-reduction into a single GPU kernel. By performing an early online reduction directly on raw logit tiles, Sparton avoids materializing the full logit matrix in memory. Our experiments demonstrate that the Sparton kernel, in isolation, achieves up to a 4.8× speedup and an order-of-magnitude reduction in peak memory usage compared to PyTorch baselines. Integrated into Splade (|V| ≈ 30k), Sparton enables a 33% larger batch size and 14% faster training with no loss in effectiveness. On a multilingual backbone (|V| ≈ 250k), these gains grow to a 26× larger batch size and 2.5× faster training.
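The early online reduction the abstract describes can be illustrated in NumPy by tiling over the sequence dimension and folding ReLU, Log1P, and the running max into the tile loop, so only one small tile of logits exists at a time. This is a minimal sketch of the idea only, not the actual Triton implementation; the function name, tile size, and tiling scheme are assumptions (the real kernel tiles the matrix multiplication itself on the GPU).

```python
import numpy as np

def fused_lsr_head(H, W, tile=32):
    """Sketch of an online ReLU -> Log1P -> max reduction over logit tiles.

    H: (L, D) hidden states; W: (D, V) LM-head weight.
    Peak extra memory is one (tile, V) slice instead of the full (L, V)
    logit matrix.
    """
    L, _ = H.shape
    V = W.shape[1]
    # log1p(relu(x)) >= 0, so zero is a safe identity for the running max.
    out = np.zeros(V, dtype=H.dtype)
    for l0 in range(0, L, tile):
        tile_logits = H[l0:l0 + tile] @ W                # (tile, V) only
        tile_rep = np.log1p(np.maximum(tile_logits, 0.0))
        np.maximum(out, tile_rep.max(axis=0), out=out)   # early online max
    return out
```

The running-max accumulator is what lets the reduction happen "online": each tile's contribution is folded in and discarded, which is the same reason the fused kernel never needs the full logit matrix in GPU memory.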
Problem

Research questions and friction points this paper is trying to address.

Learned Sparse Retrieval
memory bottleneck
large vocabulary
logit matrix
training efficiency
Innovation

Methods, ideas, or system contributions that make the work stand out.

Learned Sparse Retrieval
Triton kernel
memory-efficient
fused kernel
LM head