Decomposing Monosemanticity using Dictionary Learning Visualized
A visualized explanation of the paper Towards Monosemanticity - Decomposing Language Models With Dictionary Learning
The paper describes the results of a toy model that Anthropic trained to understand what's going on inside a neural network. However, without an ML background, it can be very hard to follow. We're going to explain this paper as simply as possible.
The results of our breakdown -
The best way to understand what's going on here is to understand the motivations of the Mechanistic Interpretability team at Anthropic.
ML systems in general are black boxes - we humans don't understand what's going on inside of them.
In theory, ML is just statistics. It's just curve fitting with a huge number of parameters. If you've ever taken a statistics class and have done some sort of regression, that's all ML is. The only difference is that it's done at a much larger scale. Architectures like neural nets, convolutional nets, transformers, etc. allow us to work more efficiently with this large number of parameters. BUT always remember, ML is just curve fitting.
The question is, how do statistics get us these insane abilities? How do statistics allow us to create a model that chats back to us as if it were human?
We don't know.
The goal of Mechanistic Interpretability is to understand what's going on here. To understand all of the details of this black box. To reason about the behavior of the whole system.
How does each component in this ML system interact?
Before we talk about anything, what does the title of this paper mean? Towards Monosemanticity: Decomposing Language Models with Dictionary Learning??? Let's break it down.
Let's start with monosemanticity. Monosemanticity means having a single, unambiguous meaning. For example, "chair" always refers to an object you sit on. These are just words that have a precise meaning.
On the other hand, polysemanticity refers to something that has multiple meanings. There is some ambiguity as to what is meant when this word is referred to. If you have to look at context, it's polysemantic.
Next, decomposing language models. This refers to analyzing and breaking down large, complex AI models to understand how they work internally.
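Finally, with dictionary learning. Dictionary learning is a method in ML that is used to find representations in data: it looks for a set of basic elements (the dictionary) such that each data point can be written as a combination of only a few of them. It's typically useful for sparse data - data where most of the values are zero but a rare value isn't, e.g. [0,0,0,0,0,0,0,0,1,0,0,0,0]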
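To make the sparse representation idea concrete, here is a minimal sketch in Python with NumPy. The dictionary below is hand-written and purely hypothetical; real dictionary learning would learn both the dictionary and the sparse codes from data. The sketch only shows how a sparse code picks out a few dictionary elements to build a signal.

```python
import numpy as np

# Hypothetical dictionary: 5 "atoms" (columns), each a direction in 3-D space.
# There are more atoms than dimensions, so the dictionary is overcomplete.
dictionary = np.array([
    [1.0, 0.0, 0.0, 0.7, 0.0],
    [0.0, 1.0, 0.0, 0.7, 0.6],
    [0.0, 0.0, 1.0, 0.0, 0.8],
])

# Sparse code: almost every entry is zero, only atom 3 is active.
code = np.array([0.0, 0.0, 0.0, 2.0, 0.0])

# The signal is the combination of atoms selected by the code.
signal = dictionary @ code
num_active = int(np.count_nonzero(code))

print("signal:", signal)
print("active atoms:", num_active)
```

The code vector plays the role of the sparse example above: one nonzero entry among zeros.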
Putting all of this together, the goal of this paper is to break down language models with an ML technique called dictionary learning with the motivation of monosemanticity - getting a precise understanding of parts of the neural network.
Paper Background & Abstract
The precise architecture for extracting features is a sparse autoencoder. A sparse autoencoder is a neural network architecture that is designed to learn compressed representations of data. It does this by enforcing sparsity.
Features are just meaningful attributes of data that are fed into the model.
The paper uses a one-layer transformer because that is the simplest possible unit to understand these models. There are a lot of unresolved issues that arise when scaling up these models. A one-layer transformer is the best way to start the interpretability of language models.
Another note is that a neuron in a neural network is NOT the most natural unit of computation for human understanding. Fundamentally, neurons are polysemantic. More than that, it's difficult to reason about the behavior of a network in terms of individual neurons.
In another Anthropic paper, Toy Models of Superposition, they introduce a cause of polysemanticity. The superposition hypothesis says that a neural network can represent more features of data than it has neurons. It does this by assigning each feature its own linear combination of neurons.
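A toy illustration of superposition, as a sketch under stated assumptions rather than anything taken from the paper: we pack 5 hypothetical features into only 2 neurons by giving each feature its own direction in the 2-D neuron space (evenly spaced around a circle). When a single feature is active, projecting the neuron activations back onto the feature directions still identifies it.

```python
import numpy as np

n_neurons, n_features = 2, 5

# Assumed layout: each feature gets a unit direction, evenly spaced on a circle.
angles = 2 * np.pi * np.arange(n_features) / n_features
directions = np.stack([np.cos(angles), np.sin(angles)])  # shape (2, 5)

# Sparse input: only feature 2 is active.
features = np.zeros(n_features)
features[2] = 1.0

# Neuron activations are a linear combination of feature directions.
neuron_activations = directions @ features  # shape (2,)

# Read each feature back out by projecting onto its direction.
readout = directions.T @ neuron_activations
recovered = int(np.argmax(readout))

print("neuron activations:", neuron_activations)
print("recovered feature index:", recovered)
```

Every neuron responds to several features here, which is exactly why individual neurons look polysemantic. The readout only works cleanly because the input is sparse; with many features active at once, the directions would interfere.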
Two candidate ways to find sparse and interpretable features are to create models without superposition, or to use dictionary learning on a model that does exhibit superposition.
The goal of this paper is to demonstrate sparse autoencoders succeeding at extracting interpretable features.
The architecture is a one-layer transformer with a 512-neuron MLP layer. The sparse autoencoder is then trained on MLP activations from 8 billion data points.
Summary of Results
These results are directly from the paper.
- Sparse autoencoders extract relatively monosemantic features.
- Sparse autoencoders produce interpretable features that are effectively invisible in the neuron basis.
- Sparse autoencoder features can be used to intervene in and steer transformer generation.
- Sparse autoencoders produce relatively universal features.
- Features appear to "split" as we increase autoencoder size.
- Just 512 neurons can represent tens of thousands of features.
- Features connect in "finite-state automata" like systems that implement complex behaviors.
The biggest barrier to understanding neural networks is that as the model grows, the combinatorics of features learned within neurons grow.
The whole goal of the one-layer transformer is to work around this. It's the simplest possible architecture when building a language model.
The goal of decomposition is to turn activations in a network into features. Moreover, following the work of superposition, we want to have more features than we do neurons.
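By loose analogy with biology, a neuron activation can be thought of as a neuron receiving sufficient input that it triggers an action potential. In an artificial network, the activation is simply the number a neuron outputs for a given input.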
But what makes a good decomposition?
- We can interpret the conditions under which each feature is active.
- We can interpret the downstream effects of each feature.
- The features explain a significant portion of the functionality of the MLP layer.
Again, decomposing this model is just the start. We want to understand how the features work in the system as a whole.
One approach to decomposition would be to remove superposition from the model altogether. However, the Anthropic team concluded that such a technique would not result in clearly interpretable neurons.
The best technique for solving such a problem is sparse autoencoders.
Although we understand very little about features in transformers, we do know that they behave like highly sparse variables.
We seek a decomposition that is sparse and overcomplete, a problem perfectly suited to sparse dictionary learning. Overcomplete means the dictionary has more features than the activation space has dimensions (here, more features than neurons).
Sparse autoencoders are also great because they can scale to very large datasets. This is necessary to get all the features we desire.
The autoencoder has a bias at the input, a linear layer with bias and ReLU for the encoder, and another linear layer and bias for the decoder.
The bias terms are very useful.
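The autoencoder is trained with the Adam optimizer to reconstruct the MLP activations, with a sparsity penalty on the feature activations.
Scale matters a lot in this example. The autoencoder is trained on 8 billion data points.
In most ML settings, the easiest way to tell that a setup is working is to look at the validation/test loss. For sparse autoencoders, no single loss captures interpretability, so after trying other techniques the Anthropic team relied on a combination of proxies such as manual inspection, feature density, and reconstruction loss.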
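The architecture described above can be sketched in a few lines of NumPy. This is a minimal forward pass and loss, not the paper's implementation: the sizes are toy values, the weights are random rather than trained, `l1_coeff` is a hypothetical hyperparameter, and treating the "bias at the input" as subtracting the decoder bias is an assumption. The Adam training loop is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons, n_features = 8, 32  # toy sizes; the paper's MLP has 512 neurons
l1_coeff = 1e-3                # hypothetical sparsity penalty strength

# Randomly initialised parameters (a real run would train these).
W_enc = rng.normal(scale=0.1, size=(n_features, n_neurons))
b_enc = np.zeros(n_features)
W_dec = rng.normal(scale=0.1, size=(n_neurons, n_features))
b_dec = np.zeros(n_neurons)

def sae_forward(x):
    """Encode MLP activations into features, then decode them back."""
    x_centered = x - b_dec  # assumed form of the bias at the input
    features = np.maximum(0.0, x_centered @ W_enc.T + b_enc)  # linear + ReLU
    reconstruction = features @ W_dec.T + b_dec                # linear + bias
    return features, reconstruction

def sae_loss(x):
    """Mean squared reconstruction error plus an L1 penalty on features."""
    features, reconstruction = sae_forward(x)
    mse = np.mean(np.sum((x - reconstruction) ** 2, axis=1))
    l1 = np.mean(np.sum(np.abs(features), axis=1))
    return mse + l1_coeff * l1

# A batch of 4 stand-in "MLP activation" vectors.
batch = rng.normal(size=(4, n_neurons))
features, reconstruction = sae_forward(batch)
loss = sae_loss(batch)
print(features.shape, reconstruction.shape, loss)
```

Note that there are more features (32) than neurons (8): that is the overcomplete part, and the L1 penalty is what pushes most feature activations toward zero.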
Going back to the one-layer model, it has a lot of advantages in our case:
- It has far fewer "true features" than larger models. Moreover, it is easier to train and experiment with.
- We can overtrain models cheaply.
- We can carefully analyze the effects features have on the logit outputs.
Detailed Investigations of Individual Features
The most important claim of this paper is that dictionary learning can extract features that are significantly more monosemantic than neurons.
The learned feature that is most active in each context is usually the most important.
The specificity of each feature is really important. That's what it takes to rule out polysemanticity. Again, with polysemanticity, neurons can activate for unrelated concepts.
To measure specificity, the paper estimates a log-likelihood ratio for each token under the hypothesized concept and compares it against the feature's activation. Large feature activations have large impacts on model predictions, so getting their interpretation right matters most.
The plot of the distribution of feature activations weighted by activation level is also important.
Activation sensitivity can be measured using the Pearson correlation and by taking the magnitude into account.
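As a small sketch of the correlation step, the snippet below computes a Pearson correlation with NumPy between a feature's activation and a proxy score for the concept. Both arrays are made-up illustrative numbers, one entry per token, not data from the paper.

```python
import numpy as np

# Made-up values: one entry per token in some sample of text.
feature_activation = np.array([0.0, 0.1, 0.0, 2.5, 3.0, 0.2, 0.0, 1.8])
proxy_score = np.array([-1.0, -0.5, -1.2, 4.0, 5.5, 0.0, -0.8, 3.1])

# np.corrcoef returns the 2x2 correlation matrix; the off-diagonal entry
# is the Pearson correlation between the two series.
pearson_r = np.corrcoef(feature_activation, proxy_score)[0, 1]
print("Pearson r:", pearson_r)
```

A value near 1 means the feature tends to fire strongly exactly where the proxy says the concept is present.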
One caveat is that the autoencoder is trained on model activations. This could be a problem because the learned feature set could in theory represent structure in the training data alone, without any relevance to the network's function.
An important property of a feature is whether it's universal. We want to make sure that the feature is generally identifiable across models, tasks, and datasets.
Using either human recognition or LLMs, features are more interpretable than neurons. They're quite interpretable overall.
The absolute best way to measure the interpretability of a model is human judgment. A blinded annotator scores features and neurons based on how interpretable they are.
To avoid "overfitting" and other data bias problems, features are drawn randomly from a sample.
The problem with human judgment is it's labor intensive. There are a ton more features than neurons. Moreover, imagine a scenario where this is heavily scaled up.
Another way to analyze interpretability is to use LLMs. Just as a human would analyze features, an LLM is asked to explain them.
The results show that Claude is able to explain and predict activations for features significantly better than for neurons.
LLMs can also be used in a second way: the feature explanations generated in the previous analysis are used to predict the feature's effect on unseen logits, that is, which tokens are likely to come next.
A last way to gauge interpretability is activation interval analysis. Rather than asking if a feature seems interpretable, we ask whether a range of activations is consistent with the overall hypothesis suggested by all the feature's activations.
Think of it in terms of feature activation strength. High-activating features are the most consistent.
There are some problems with gauging interpretability. All these methods listed above have some caveats.
Firstly, the feature activations are skewed towards the lower intervals.
Secondly, the features are sampled uniformly and interpretability might be correlated with importance.
Even with these caveats, the features still explain a lot. In the A/1 run, features recover 79% of the log-likelihood loss reduction provided by the MLP layer.
Specific interpretable features are used by the model in interpretable ways.
It's also important to understand what the features reflect. Do they tell us about the model or about the data?
The model's activations reflect two things: the distribution of the dataset and the way that distribution is transformed by the model.
To assess interpretability of the downstream feature effects, we use:
- Logit weight inspection
- Feature ablation
- Pinned feature sampling
These all provide evidence of features found being used by the model.
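Of the three techniques listed above, feature ablation is the easiest to sketch. The idea, under the assumption that a feature's contribution to the MLP activations is its activation times its decoder direction, is to subtract that contribution and see how the downstream output changes. All arrays below are random stand-ins and `ablate_feature` is a hypothetical helper, not a function from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons, n_features = 8, 32  # toy sizes

# Random stand-ins for a trained decoder and one token's activations.
W_dec = rng.normal(size=(n_neurons, n_features))
mlp_activation = rng.normal(size=n_neurons)
feature_activations = np.maximum(0.0, rng.normal(size=n_features))

def ablate_feature(activation, feature_acts, decoder, index):
    """Remove one feature's contribution from the MLP activation vector."""
    return activation - feature_acts[index] * decoder[:, index]

# Ablate the most active feature on this token.
feature_index = int(np.argmax(feature_activations))
ablated = ablate_feature(mlp_activation, feature_activations, W_dec, feature_index)

# The difference is exactly the contribution that was removed.
removed = mlp_activation - ablated
print("ablated feature:", feature_index)
print("norm of removed contribution:", np.linalg.norm(removed))
```

In a real experiment the ablated activations would be fed through the rest of the model, and the change in the output logits would show what the feature was doing.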
Ultimately, the goal of Mechanistic Interpretability is to understand neural networks.
The features found vary enormously, and no concise summary will capture their breadth. It's best to focus on abstract properties and patterns noticed.
A very common pattern is the prevalence of context features and token-in-context features.
Features are often connected by feature splitting: they appear as pure context features or token features in dictionaries with few learned features, and then split into token-in-context features as more features are learned.
Moreover, most of the features in a one-layer model can be interpreted as action features, that is, features that incorporate information about behavior.
A striking thing about features in general is that they appear in clusters. More features are developed in similar clusters as the total number of learned sparse features increases. This is referred to as feature splitting.
The theory is that "true features" are clustered into sets of similar features, which the model puts in very tight superposition.
In general, for dictionary learning, the correct number of features is less important than it might initially seem.
One of the biggest questions is whether features are universal. Do the same features form across different models?
Universality is a strong claim because it means the same features learned in one model should also be learned in another. Showing it, however, is evidence that the features are reproducible.
In this paper, substantial universality is observed in all the experiments.
Universality is observed by comparing the features of two one-layer transformers. Think of the two models as identical except that each starts from a different random seed.
The two ways to judge the generality are to measure the activation similarity between two features and the similarity of their downstream effects.
Going back to how features connect, features can also assemble into larger systems. These are formed by one feature increasing the probability of certain tokens, which in turn cause another feature to fire on the next step.
The most closely related work is Anthropic's previous work on Toy Models of Superposition.
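Activation similarity can be sketched as follows: run both models on the same tokens, then correlate each feature in model A with each feature in model B and keep the best match. The data below is synthetic (model B's features are a shuffled, slightly noisy copy of model A's), and `best_matches` is a hypothetical helper written for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens = 200

# Synthetic feature activations for model A: shape (tokens, features).
acts_a = np.maximum(0.0, rng.normal(size=(n_tokens, 3)))

# Model B: the same features in a different order, plus a little noise.
permutation = np.array([2, 0, 1])
acts_b = acts_a[:, permutation] + 0.01 * rng.normal(size=(n_tokens, 3))

def best_matches(a, b):
    """For each feature in a, find the most correlated feature in b."""
    n_a = a.shape[1]
    # Full correlation matrix over all features; keep the A-versus-B block.
    corr = np.corrcoef(a.T, b.T)[:n_a, n_a:]
    return corr.argmax(axis=1), corr.max(axis=1)

match_index, match_corr = best_matches(acts_a, acts_b)
print("best match in B for each A feature:", match_index)
print("correlation of each match:", match_corr)
```

High maximum correlations suggest the two models learned the same features, even though the features sit at different indices in each dictionary.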
There are many works that suggest superposition is real. Moreover, there are mathematical frameworks for thinking about when and why superposition occurs.
Related work suggests that models memorize small datasets by storing data points in superposition, whereas with large datasets they generalize by learning features.
There is also a lot of work on disentanglement where a model learns representations of different dimensions that correspond to explanatory factors of the data. This is more of an engineering/efficiency problem.
In general, disentanglement tries to align explanatory factors with individual dimensions of the model, whereas the superposition hypothesis predicts that forcing this alignment hurts performance.
At a minimum, this paper shows that features seem to clump together in higher-density groups of related features. Features also need to be one-dimensional objects.
A very common pattern in this work is token-in-context features - features that emerge via feature splitting as the dictionary size increases.
To continue work on this, here is a list of future work from the Anthropic team:
- Sparse auto-encoder scaling
- Scaling laws for dictionary learning
- How can we recognize good features?
- Scalability of Analysis
- Algorithmic improvements for sparse autoencoders
- Attentional superposition
- Theory of superposition and features