AI Summary
Problem: Inserting sparse autoencoders (SAEs) into language models (LMs) substantially increases cross-entropy loss and degrades downstream performance, and closing this gap through SAE training alone is computationally expensive.
Method: We propose LoRA-SAE, a framework that applies low-rank adaptation (LoRA) to fine-tune the LM itself for compatibility with a previously trained SAE, rather than further training the SAE post hoc. The approach supports adapting multiple SAEs in parallel and deployment across model layers.
Results: Evaluated on Gemma-2-2B and Llama-3.2-1B, LoRA-SAE reduces the SAE-induced increase in cross-entropy loss by 30–55%, reaches the same downstream loss 2×–20× faster than end-to-end SAE training, and improves downstream task metrics. It achieves Pareto improvements in the performance–efficiency trade-off, offering a scalable, computationally efficient pathway toward model-level interpretability.
Abstract
Sparse autoencoders (SAEs) decompose language model representations into a sparse set of linear latent vectors. Recent works have improved SAEs using language model gradients, but these techniques require many expensive backward passes during training and still cause a significant increase in cross entropy loss when SAE reconstructions are inserted into the model. In this work, we improve on these limitations by taking a fundamentally different approach: we use low-rank adaptation (LoRA) to finetune the language model itself around a previously trained SAE. We analyze our method across SAE sparsity, SAE width, language model size, LoRA rank, and model layer on the Gemma Scope family of SAEs. In these settings, our method reduces the cross entropy loss gap by 30% to 55% when SAEs are inserted during the forward pass. We also find that compared to end-to-end (e2e) SAEs, our approach achieves the same downstream cross entropy loss 3× to 20× faster on Gemma-2-2B and 2× to 10× faster on Llama-3.2-1B. We further show that our technique improves downstream metrics and can adapt multiple SAEs at once. Our results demonstrate that improving model interpretability is not limited to post-hoc SAE training; Pareto improvements can also be achieved by directly optimizing the model itself.