🤖 AI Summary
Training large-scale sparse autoencoders (SAEs) incurs heavy computational and memory overhead, especially at large dictionary sizes, which limits their scalability for interpretability research.
Method: We propose KronSAE, a novel architecture featuring (1) a Kronecker-product decomposition of the latent space that drastically reduces the parameters and compute of the encoder's linear transformation, and (2) mAND, a differentiable activation function approximating binary AND logic that improves performance and the semantic interpretability of learned features in the factorized setting. The decoder continues to benefit from existing sparse-aware kernels.
Results: On feature disentanglement in language-model hidden states, KronSAE substantially reduces memory and FLOP consumption, enabling training with much larger dictionaries. At the same time it improves reconstruction accuracy and direction-level interpretability (assessed via feature alignment and causal mediation analysis) without compromising sparsity or fidelity.
📝 Abstract
Sparse Autoencoders (SAEs) have demonstrated significant promise in interpreting the hidden states of language models by decomposing them into interpretable latent directions. However, training SAEs at scale remains challenging, especially when large dictionary sizes are used. While decoders can leverage sparse-aware kernels for efficiency, encoders still require computationally intensive linear operations with large output dimensions. To address this, we propose KronSAE, a novel architecture that factorizes the latent representation via Kronecker product decomposition, drastically reducing memory and computational overhead. Furthermore, we introduce mAND, a differentiable activation function approximating the binary AND operation, which improves interpretability and performance in our factorized framework.
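To make the parameter saving concrete, here is a minimal NumPy sketch of the idea as described in the abstract. The exact formulation in the paper is not shown here, so the shapes, the pairwise composition, and the specific `mand` surrogate (an elementwise min of ReLUs) are illustrative assumptions: two small encoders produce pre-activations of sizes m and n, and every pair is combined into an (m·n)-dimensional latent, so the encoder needs d·(m+n) weights instead of d·(m·n).

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def mand(a, b):
    # Hypothetical differentiable surrogate for binary AND: the output
    # is nonzero only when both inputs are positive. The paper's actual
    # mAND definition may differ; this is an assumed stand-in.
    return np.minimum(relu(a), relu(b))

d, m, n = 64, 16, 8            # model dim; factor widths (m * n latents)
W_a = rng.normal(size=(m, d)) / np.sqrt(d)
W_b = rng.normal(size=(n, d)) / np.sqrt(d)

x = rng.normal(size=d)         # a hidden state to encode
a, b = W_a @ x, W_b @ x        # two small pre-activation vectors

# Compose an (m*n)-dimensional latent from every (a_i, b_j) pair,
# mirroring a Kronecker-product structure over the latent space.
z = mand(a[:, None], b[None, :]).ravel()

# Encoder parameter count: d*(m+n) versus d*(m*n) for a dense encoder.
print(W_a.size + W_b.size, d * m * n)   # 1536 vs 8192
```

With d = 64, m = 16, n = 8 the factorized encoder uses 1,536 weights where a dense encoder producing the same 128 latents would use 8,192; the gap widens quadratically as the dictionary grows.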