The SAE is trained on residual-stream activations from layer 12. The goal is to decompose each activation as a sparse sum of unit-norm dictionary directions; those directions are our candidate features.
If we encoded x directly, b_dec would have to absorb the dataset
mean plus any reconstruction offset, and the encoder would burn capacity on the
constant offset. Subtracting b_dec before encoding decouples the two:
b_dec ends up at (approximately) the dataset mean, and the encoder/decoder
pair learns the deviations.
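A minimal NumPy sketch of the centering step, under assumptions: toy sizes, random weights, and hypothetical helper names `encode` and `decode` (not from the training code). With b_dec set to the data mean, the pre-activations average to roughly zero over the dataset; with no subtraction, every feature carries a constant offset.

```python
import numpy as np

def encode(x, W_enc, b_enc, b_dec):
    # Pre-activations are computed on the centered input x - b_dec.
    return (x - b_dec) @ W_enc.T + b_enc

def decode(z, W_dec, b_dec):
    # b_dec is added back, so it is the reconstruction when no feature fires.
    return z @ W_dec + b_dec

rng = np.random.default_rng(0)
d_model, n_features = 8, 32                   # toy sizes, not the real config
x = rng.normal(size=(1000, d_model)) + 5.0    # data with a large constant offset

W_dec = rng.normal(size=(n_features, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)  # unit-norm rows
W_enc = W_dec.copy()                          # tied init, as in the config table
b_enc = np.zeros(n_features)
b_dec = x.mean(axis=0)                        # b_dec at the dataset mean

pre_centered = encode(x, W_enc, b_enc, b_dec)
pre_raw = encode(x, W_enc, b_enc, np.zeros(d_model))  # what "encode x directly" would see

centered_offset = np.abs(pre_centered.mean(axis=0)).max()  # close to zero
raw_offset = np.abs(pre_raw.mean(axis=0)).max()            # dominated by the data mean
empty_recon = decode(np.zeros((1, n_features)), W_dec, b_dec)
print(centered_offset, raw_offset)
```

Here `empty_recon` equals b_dec, which is the sense in which b_dec carries the constant part and the encoder/decoder pair carries the deviations.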
The pair (W_enc, W_dec) has a scale gauge: doubling row i of W_dec while
halving row i of W_enc (and b_enc[i]) halves z_i but leaves feature i's
contribution to the reconstruction unchanged. Without a constraint, the
decoder norms drift, and anything that compares activations across features
(the L1 penalty, the TopK ranking) becomes meaningless.
Fixing decoder row norm = 1 puts each feature's "size" entirely in its activation
magnitude z_i, which is what we want for interpretation.
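The gauge can be checked numerically. The sketch below is illustrative only: it assumes both weight matrices are stored as (n_features, d_model), so feature i is a row in each, and it uses a plain ReLU encoder, because under TopK the rescaling would also change which features get selected. `relu_sae` and `normalize_decoder_rows` are hypothetical helper names.

```python
import numpy as np

def relu_sae(x, W_enc, b_enc, W_dec, b_dec):
    # Plain ReLU autoencoder with pre-encoder bias subtraction.
    z = np.maximum((x - b_dec) @ W_enc.T + b_enc, 0.0)
    return z, z @ W_dec + b_dec

def normalize_decoder_rows(W_dec, eps=1e-8):
    # Project every decoder row back to unit L2 norm (e.g. after each step).
    norms = np.linalg.norm(W_dec, axis=1, keepdims=True)
    return W_dec / np.maximum(norms, eps)

rng = np.random.default_rng(0)
d_model, n_features, i = 8, 32, 3             # toy sizes; i is the rescaled feature
x = rng.normal(size=(16, d_model))
W_enc = rng.normal(size=(n_features, d_model))
b_enc = rng.normal(size=n_features)
W_dec = rng.normal(size=(n_features, d_model))
b_dec = rng.normal(size=d_model)

z1, recon1 = relu_sae(x, W_enc, b_enc, W_dec, b_dec)

# Apply the gauge transformation to feature i only.
W_enc2, b_enc2, W_dec2 = W_enc.copy(), b_enc.copy(), W_dec.copy()
W_dec2[i] *= 2.0
W_enc2[i] /= 2.0
b_enc2[i] /= 2.0
z2, recon2 = relu_sae(x, W_enc2, b_enc2, W_dec2, b_dec)

same_recon = np.allclose(recon1, recon2)          # reconstruction is unchanged
halved = np.allclose(z2[:, i], z1[:, i] / 2.0)    # but z_i is half as large
row_norms = np.linalg.norm(normalize_decoder_rows(W_dec2), axis=1)
print(same_recon, halved, row_norms.min(), row_norms.max())
```

The activation of feature i changes while the output does not, which is why activation magnitudes are only comparable once the decoder rows are pinned to a common norm.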
With an L1 penalty you have to tune the sparsity coefficient, and L0 (the actual number of active features per example) is an emergent property that drifts during training. TopK sets L0 = k directly as an architectural constraint: there is no sparsity coefficient to tune (k itself is still a choice), dead-feature pathologies are fewer, and training runs are easier to interpret because L0 is fixed.
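As a rough illustration of the constraint, here is one way a TopK activation could be written in NumPy. `topk_activation` is a hypothetical name, and this variant applies no ReLU after selection; implementations that do add one get L0 ≤ k, not exactly k.

```python
import numpy as np

def topk_activation(pre, k):
    # Keep the k largest pre-activations in each row and zero everything else.
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]   # indices of the top k per row
    rows = np.arange(pre.shape[0])[:, None]
    z = np.zeros_like(pre)
    z[rows, idx] = pre[rows, idx]
    return z

rng = np.random.default_rng(0)
pre = rng.normal(size=(5, 64))    # toy batch: 5 examples, 64 features
k = 4
z = topk_activation(pre, k)

l0 = (z != 0).sum(axis=1)         # active features per example
print(l0)
```

Every entry of `l0` equals k regardless of the input scale, which is the property the L1 penalty cannot guarantee.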
A dead feature is one whose pre-activation never enters the top-k of any example for a long window. Its encoder row stops receiving gradient (it's not in the active set) and its decoder row stops being used. Once dead, it stays dead unless we intervene. Standard fix: every N steps, find dead features and re-initialize their encoder/decoder rows to point at high-error directions in the current batch ("resampling"). Not implemented yet — we'll add it once we see the dead fraction.
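Measuring the dead fraction only needs a per-feature counter. The sketch below covers detection only, not resampling; the class name `DeadFeatureTracker` and the `window` parameter are assumptions for illustration, and "dead" here means "has not fired in any example for `window` consecutive steps".

```python
import numpy as np

class DeadFeatureTracker:
    """Counts optimizer steps since each feature last fired."""

    def __init__(self, n_features, window):
        self.window = window
        self.steps_since_fired = np.zeros(n_features, dtype=np.int64)

    def update(self, z):
        # z: (batch, n_features) sparse activations from one training step.
        fired = (z != 0).any(axis=0)
        self.steps_since_fired += 1
        self.steps_since_fired[fired] = 0

    def dead_mask(self):
        return self.steps_since_fired >= self.window

    def dead_fraction(self):
        return float(self.dead_mask().mean())

# Toy run: features 0..3 fire on every step, features 4 and 5 never do.
tracker = DeadFeatureTracker(n_features=6, window=3)
z = np.zeros((2, 6))
z[:, :4] = 1.0
for _ in range(3):
    tracker.update(z)

dead = tracker.dead_mask().tolist()
fraction = tracker.dead_fraction()
print(dead, fraction)
```

The indices where `dead_mask()` is true are the rows a later resampling step would re-initialize.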
| field | value |
|---|---|
| source model | google/gemma-2-2b |
| hook site | model.model.layers[12] (residual stream output) |
| d_model | 2304 |
| n_features | 18,432 (= 8 × d_model) |
| k (TopK) | 32 |
| expected L0 | = 32 by construction |
| decoder row norm | 1.0 (enforced after every step) |
| encoder init | tied: W_enc ← W_dec at init (Anthropic recipe) |
| b_dec init | dataset mean of activations |
| loss | MSE per-element (no L1 penalty — sparsity comes from TopK) |
| optimizer | AdamW, lr=3e-4, no weight decay |
| dataset | 500k Gemma-2-2b layer-12 activations from Pile-10k |
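For reference, the settings above could be captured in a small config object. This is only a sketch: the class name `SAEConfig` and all field names are hypothetical, and the values are copied from the table.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SAEConfig:
    # Field names are illustrative; values come from the table above.
    source_model: str = "google/gemma-2-2b"
    hook_layer: int = 12
    d_model: int = 2304
    expansion: int = 8            # n_features = expansion * d_model
    k: int = 32                   # TopK, so L0 = k by construction
    lr: float = 3e-4
    weight_decay: float = 0.0     # AdamW with no weight decay
    n_activations: int = 500_000

    @property
    def n_features(self) -> int:
        return self.expansion * self.d_model

cfg = SAEConfig()
print(cfg.n_features, cfg.k)
```

Deriving `n_features` from `expansion` keeps the two numbers from drifting apart if d_model changes.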
| parameter | shape | count |
|---|---|---|
| W_enc | (18,432, 2304) | 42.5 M |
| b_enc | (18,432,) | 18.4 k |
| W_dec | (18,432, 2304) | 42.5 M |
| b_dec | (2304,) | 2.3 k |
| Total | | 85.0 M |
For comparison, the source model (Gemma 2 2B) has ~2.61 B parameters. The SAE is ~3% the size of the model it's analyzing — but it only has to fit one layer.
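The counts in the table can be reproduced from the shapes alone. A short check, taking the ~2.61 B figure for the source model as given above (it is an approximation, so the ratio is too):

```python
import math

d_model = 2304
n_features = 8 * d_model

shapes = {
    "W_enc": (n_features, d_model),
    "b_enc": (n_features,),
    "W_dec": (n_features, d_model),
    "b_dec": (d_model,),
}

# Number of scalars in each parameter tensor.
counts = {name: math.prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())

source_params = 2.61e9            # approximate size of the source model
ratio = total / source_params

for name, n in counts.items():
    print(f"{name:6s} {n:>12,d}")
print(f"total  {total:>12,d}  ({ratio:.1%} of the source model)")
```

The two weight matrices account for nearly all of the total; the biases are a rounding error.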