Modular Training of Neural Networks aids Interpretability

📅 2025-02-04
🤖 AI Summary
Neural networks are hard to interpret, in part because their internal computations are highly coupled. To address this, we propose an end-to-end modular training framework. Its core is a differentiable measure of cluster separability (clusterability), on which we build a clusterability loss. This loss brings spectral graph clustering into the training objective and pushes the network toward non-interacting, fine-grained, functionally specialized subcircuits. Unlike post-hoc interpretability methods, the approach shapes the model during training. It is architecture-agnostic, and we apply it to CNNs, small transformers, and language models. Experiments on MNIST, CIFAR, modular addition, and language models show better-separated subcircuits with more specific functions, yielding smaller, simpler computational pathways. This suggests a promising direction for automated circuit discovery and interpretable AI.
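To make the idea of a differentiable clusterability loss concrete, here is a minimal sketch. It is not the authors' implementation and rests on several assumptions. Clusterability is taken to be the share of squared weight mass that stays inside clusters of a fixed partition, so the penalty is the complementary cross-cluster share; the paper's exact definition may differ. The penalty is added with a weight λ to a mean-squared-error task loss on a single linear layer. The function names `cross_cluster_fraction` and `train`, the learning rate, and the toy data are all hypothetical.

```python
import numpy as np


def cross_cluster_fraction(W, mask, eps=1e-12):
    """Share of squared weight mass crossing cluster boundaries, plus its gradient.

    mask[i, j] is True when output neuron i and input neuron j share a cluster.
    Returns (frac, grad), where frac = C / T with C = cross-cluster mass, T = total mass.
    """
    off = (~mask).astype(float)
    total = np.sum(W ** 2) + eps
    frac = np.sum(W ** 2 * off) / total
    grad = 2.0 * W * (off - frac) / total  # quotient rule: d(C/T)/dW
    return frac, grad


def train(X, Y, mask, lam, lr=0.05, steps=2000, seed=0):
    """Gradient descent on a linear layer: mean squared error + lam * cross-cluster fraction."""
    rng = np.random.default_rng(seed)
    W = rng.normal(scale=0.1, size=(Y.shape[1], X.shape[1]))
    for _ in range(steps):
        residual = X @ W.T - Y                     # shape (N, n_out)
        grad_task = 2.0 * residual.T @ X / len(X)  # gradient of sum of squared errors / N
        _, grad_cluster = cross_cluster_fraction(W, mask)
        W -= lr * (grad_task + lam * grad_cluster)
    return W


# Toy target: two clusters of two neurons each, with weaker (0.5) cross-cluster weights.
labels = np.array([0, 0, 1, 1])
mask = labels[:, None] == labels[None, :]
W_true = np.where(mask, 1.0, 0.5)
X = np.random.default_rng(42).normal(size=(200, 4))
Y = X @ W_true.T

for lam in (0.0, 1.0):
    frac, _ = cross_cluster_fraction(train(X, Y, mask, lam), mask)
    print(f"lambda={lam}: cross-cluster weight fraction = {frac:.3f}")
```

With λ = 0 the layer recovers the target, whose cross-cluster fraction is 2/10 = 0.2. With λ = 1 the penalty gives up a little task accuracy in exchange for less weight crossing the partition. The partition is fixed in advance here for simplicity.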

📝 Abstract
An approach to improve neural network interpretability is via clusterability, i.e., splitting a model into disjoint clusters that can be studied independently. We define a measure for clusterability and show that pre-trained models form highly enmeshed clusters via spectral graph clustering. We thus train models to be more modular using a "clusterability loss" function that encourages the formation of non-interacting clusters. Using automated interpretability techniques, we show that our method can help train models that are more modular and learn different, disjoint, and smaller circuits. We investigate CNNs trained on MNIST and CIFAR, small transformers trained on modular addition, and language models. Our approach provides a promising direction for training neural networks that learn simpler functions and are easier to interpret.
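The abstract finds clusters with spectral graph clustering. The sketch below shows one way to do this for a single layer and is an assumption, not the paper's procedure. It treats |W| as a bipartite graph between output and input neurons and applies a simplified form of bipartite spectral co-clustering to split them into two clusters. It thresholds the second singular vectors at zero instead of running k-means on several of them. The function name `spectral_bipartition` is hypothetical. For more than two clusters, scikit-learn's `sklearn.cluster.SpectralCoclustering` implements the general algorithm.

```python
import numpy as np


def spectral_bipartition(W: np.ndarray, eps: float = 1e-12):
    """Split a layer's output and input neurons into two clusters.

    |W| is treated as the adjacency matrix of a bipartite graph
    (rows = output neurons, columns = input neurons).
    """
    A = np.abs(W)
    d_out = A.sum(axis=1) + eps  # weighted degree of each output neuron
    d_in = A.sum(axis=0) + eps   # weighted degree of each input neuron
    # Degree-normalized adjacency D_out^{-1/2} A D_in^{-1/2}.
    A_norm = A / np.sqrt(d_out)[:, None] / np.sqrt(d_in)[None, :]
    U, _, Vt = np.linalg.svd(A_norm, full_matrices=False)
    # The top singular pair is trivial (proportional to sqrt of the degrees),
    # so the second pair carries the cluster structure.
    u2 = U[:, 1] / np.sqrt(d_out)
    v2 = Vt[1, :] / np.sqrt(d_in)
    # Threshold at zero instead of running k-means (enough for two clusters).
    return (u2 >= 0).astype(int), (v2 >= 0).astype(int)


# Two groups of neurons with weak (0.1) coupling between them.
W = np.array([[1.0, 1.0, 0.1, 0.1],
              [1.0, 1.0, 0.1, 0.1],
              [0.1, 0.1, 1.0, 1.0],
              [0.1, 0.1, 1.0, 1.0]])
out_labels, in_labels = spectral_bipartition(W)
print(out_labels, in_labels)  # neurons {0, 1} and {2, 3} fall into different clusters
```

Which group receives label 0 depends on the sign the SVD returns. Only the partition itself is meaningful.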
Problem

Research questions and friction points this paper is trying to address.

Enhancing neural network interpretability through modular training
Promoting non-interacting clusters using clusterability loss
Training models with simpler, disjoint circuits that are easier to analyze
Innovation

Methods, ideas, or system contributions that make the work stand out.

Modular neural network training
Clusterability loss function
Spectral graph clustering