TopK Sparse Autoencoder — Gemma 2 2B, layer 12

Trained on residual-stream activations from layer 12. The goal is to decompose each activation as a sparse sum of unit-norm dictionary directions; those directions are our candidate features.

Diagram legend: activation in/out, centered residual, encoder / decoder, ReLU, sparse code (TopK), loss, constraint op, metric

Forward pass: activation → sparse code → reconstruction

x: residual stream activation, [B, 2304] fp32
    one vector per token, harvested from layer 12 of Gemma 2 2B
  ↓ subtract decoder bias (pre-bias trick)
x - b_dec, [B, 2304]
    b_dec is initialized to the dataset mean, so the encoder learns deviations only
  ↓ @ W_encᵀ + b_enc (encoder)
z_pre, [B, 18,432]
    W_enc shape = (18,432, 2304); rows are feature detectors
  ↓ ReLU
z_relu, [B, 18,432]
    non-negative activations: many small, some large
  ↓ TopK along feature dim (k = 32)
z: sparse code, [B, 18,432], L0 = 32
    keep the 32 largest per row, zero the rest. L0 (number of active features) is exactly k whenever at least k entries of z_relu are positive, which we expect for essentially every example; it can only fall below k if fewer than k entries survive the ReLU.
  ↓ @ W_dec (decoder)
z @ W_dec, [B, 2304]
    W_dec shape = (18,432, 2304); row i is the unit-norm direction in residual space associated with feature i. x̂ - b_dec is a sparse sum of k of these directions.
  ↓ + b_dec
x̂: reconstruction, [B, 2304] fp32
    loss = ((x - x̂)²).mean()
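
The flow above can be sketched in a few lines of NumPy. This is a minimal sketch rather than the training code: the function name topk_sae_forward and the toy sizes are assumptions made for illustration, and NumPy stands in for whatever tensor library the real implementation uses.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Forward pass of a TopK SAE. W_enc and W_dec are both (n_features, d_model)."""
    z_pre = (x - b_dec) @ W_enc.T + b_enc          # [B, n_features]
    z_relu = np.maximum(z_pre, 0.0)                # non-negative activations
    # Indices of the k largest entries in each row (order within the k is arbitrary).
    idx = np.argpartition(z_relu, -k, axis=1)[:, -k:]
    z = np.zeros_like(z_relu)
    np.put_along_axis(z, idx, np.take_along_axis(z_relu, idx, axis=1), axis=1)
    x_hat = z @ W_dec + b_dec                      # [B, d_model]
    return x_hat, z

# Toy sizes for illustration only (the real run uses d_model=2304, n_features=18,432, k=32).
rng = np.random.default_rng(0)
B, d_model, n_features, k = 8, 16, 64, 4

x = rng.standard_normal((B, d_model))
W_dec = rng.standard_normal((n_features, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)   # unit-norm decoder rows
W_enc = W_dec.copy()                                     # tied init
b_enc = np.zeros(n_features)
b_dec = x.mean(axis=0)                                   # stand-in for the dataset mean

x_hat, z = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
loss = ((x - x_hat) ** 2).mean()
l0_per_example = (z != 0).sum(axis=1)                    # at most k per row
```

Selecting from z_relu rather than z_pre means a selected entry can be zero when fewer than k pre-activations are positive, which is why L0 is bounded by k rather than pinned to it in this sketch.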

One training step

sample batch from memmap, [B=4096, 2304]
    seeded uniform sampling with replacement over 500k vectors
forward pass → (x̂, z)
    see the forward-pass flow above
loss = mean((x - x̂)²)
    no L1, no auxiliary losses (yet). TopK gives sparsity for free.
loss.backward()
remove radial component of W_dec.grad
    grad = grad − (grad · W_dec) ⊙ W_dec, per row, where (grad · W_dec) is the row-wise dot product.
    Why: W_dec rows live on the unit sphere. The radial gradient component would be undone by the renormalize step anyway. Removing it makes Adam's running statistics describe only tangent (useful) updates.
opt.step() (AdamW)
renormalize: W_dec row-norms → 1
    hard projection back to the unit sphere after every step
FeatureFireTracker.update(z)
    per-feature counter: steps since this feature last had any nonzero activation in any example of the batch. Features with counter ≥ window are 'dead'.
every log_every steps
    log: loss, NMSE, L0, dead-fraction → metrics.jsonl
    NMSE = ‖x - x̂‖² / ‖x - x̄‖² (normalized MSE; trivial-mean predictor → 1.0).
    explained variance = 1 - NMSE.
    L0 should hover at k = 32.
    dead-fraction is what we'll address with feature resampling once we have a baseline.
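
The constraint and metric steps above are small enough to sketch directly. The following NumPy sketch is illustrative only: the helper names are hypothetical, a plain gradient step stands in for AdamW, the random array stands in for W_dec.grad, and x̄ is taken to be the batch mean (the notes do not say whether the batch or dataset mean is used).

```python
import numpy as np

def remove_radial_component(W_dec, grad):
    """Project each gradient row onto the tangent space of the unit sphere at that W_dec row."""
    radial = (grad * W_dec).sum(axis=1, keepdims=True)   # row-wise dot product
    return grad - radial * W_dec

def renormalize_rows(W_dec, eps=1e-8):
    """Hard projection of every row back to unit norm."""
    norms = np.linalg.norm(W_dec, axis=1, keepdims=True)
    return W_dec / np.maximum(norms, eps)

def nmse(x, x_hat):
    """Normalized MSE; assumption: x_bar is the mean over the batch."""
    x_bar = x.mean(axis=0, keepdims=True)
    return ((x - x_hat) ** 2).sum() / ((x - x_bar) ** 2).sum()

rng = np.random.default_rng(0)
W_old = renormalize_rows(rng.standard_normal((64, 16)))  # toy sizes
grad = rng.standard_normal(W_old.shape)                  # stand-in for W_dec.grad

tangent_grad = remove_radial_component(W_old, grad)
W_new = renormalize_rows(W_old - 1e-2 * tangent_grad)    # stand-in for opt.step() + renormalize

x = rng.standard_normal((8, 16))
mean_predictor = np.broadcast_to(x.mean(axis=0), x.shape)
explained_variance_of_mean = 1.0 - nmse(x, mean_predictor)   # trivial predictor explains nothing
```

After the projection, each gradient row is orthogonal to its decoder row, so the update moves the row along the sphere to first order and the renormalize step only has to correct second-order drift.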

Conceptual notes

Why pre-bias?

If we encoded x directly, b_dec would have to absorb the dataset mean plus any reconstruction offset, and the encoder would burn capacity on the constant offset. Subtracting b_dec before encoding decouples the two: b_dec ends up at (approximately) the dataset mean, and the encoder/decoder pair learns the deviations.

Why unit-norm decoder rows?

The pair (W_enc, W_dec) has a scale gauge: doubling row i of W_dec while halving row i of W_enc (and b_enc[i]) leaves feature i's contribution to the reconstruction unchanged. Without a constraint, the decoder norms drift, so activation magnitudes, and with them L1 penalties or TopK rankings, are no longer comparable across features. Fixing decoder row norm = 1 puts each feature's "size" entirely in its activation magnitude z_i, which is what we want for interpretation.
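
The scale gauge is easy to check numerically. The sketch below uses a plain ReLU autoencoder (no TopK, no b_dec) with toy sizes, all of which are assumptions made to keep the example short. It rescales one feature and confirms that the reconstruction is unchanged while that feature's activation shrinks by the same factor. Under TopK the shrunken activation would also rank lower, which is one reason the norm has to be pinned.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_features, batch = 16, 64, 8        # toy sizes, not the real config

W_dec = rng.standard_normal((n_features, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)
W_enc = W_dec.copy()
b_enc = 0.1 * rng.standard_normal(n_features)
x = rng.standard_normal((batch, d_model))

def relu_reconstruct(x, W_enc, b_enc, W_dec):
    """ReLU-only encode/decode; returns (activations, reconstruction)."""
    z = np.maximum(x @ W_enc.T + b_enc, 0.0)
    return z, z @ W_dec

# Rescale feature i: decoder row times c, encoder row and bias divided by c.
i, c = 3, 2.0
W_enc2, b_enc2, W_dec2 = W_enc.copy(), b_enc.copy(), W_dec.copy()
W_dec2[i] *= c
W_enc2[i] /= c
b_enc2[i] /= c

z_a, recon_a = relu_reconstruct(x, W_enc, b_enc, W_dec)
z_b, recon_b = relu_reconstruct(x, W_enc2, b_enc2, W_dec2)

same_reconstruction = np.allclose(recon_a, recon_b)
activation_scaled = np.allclose(z_b[:, i] * c, z_a[:, i])
```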

Why TopK and not L1?

L1 gives you a knob (the sparsity coefficient) you have to tune, and the L0 (actual number of active features) is an emergent property that drifts. TopK directly sets L0 = k as an architectural constraint. There is no sparsity coefficient to tune (k itself is the only sparsity hyperparameter), there are fewer dead-feature pathologies, and training curves are easier to interpret because L0 is fixed.

What "dead features" are and why they matter

A dead feature is one whose pre-activation never enters the top-k of any example for a long window. Its encoder row stops receiving gradient (it's not in the active set) and its decoder row stops being used. Once dead, it stays dead unless we intervene. Standard fix: every N steps, find dead features and re-initialize their encoder/decoder rows to point at high-error directions in the current batch ("resampling"). Not implemented yet — we'll add it once we see the dead fraction.
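
The per-feature counter used in the training step could look something like the following. Only the name FeatureFireTracker and the update(z) call come from these notes; the constructor arguments, the internals, and the dead_mask / dead_fraction helpers are assumptions made for illustration.

```python
import numpy as np

class FeatureFireTracker:
    """Counts, per feature, the steps since it last fired on any example in a batch."""

    def __init__(self, n_features, window):
        self.window = window
        self.steps_since_fired = np.zeros(n_features, dtype=np.int64)

    def update(self, z):
        """z: [B, n_features] sparse code for one training step."""
        fired = (z != 0).any(axis=0)          # fired on at least one example
        self.steps_since_fired += 1
        self.steps_since_fired[fired] = 0     # reset counters for features that fired

    def dead_mask(self):
        return self.steps_since_fired >= self.window

    def dead_fraction(self):
        return float(self.dead_mask().mean())

# Toy usage: 4 features, only features 0 and 1 ever fire.
tracker = FeatureFireTracker(n_features=4, window=3)
z = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 2.0, 0.0, 0.0]])
for _ in range(3):
    tracker.update(z)
dead_after_three_steps = tracker.dead_fraction()
```

A resampling routine would read dead_mask() to decide which encoder and decoder rows to re-initialize.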

Config

field              value
source model       google/gemma-2-2b
hook site          model.model.layers[12] (residual stream output)
d_model            2304
n_features         18,432 (= 8 × d_model)
k (TopK)           32
expected L0        = 32 by construction
decoder row norm   1.0 (enforced after every step)
encoder init       tied: W_enc ← W_dec at init (Anthropic recipe)
b_dec init         dataset mean of activations
loss               MSE per-element (no L1 penalty; sparsity comes from TopK)
optimizer          AdamW, lr=3e-4, no weight decay
dataset            500k Gemma-2-2b layer-12 activations from Pile-10k
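
For reference, the same settings could be carried in code as a small frozen dataclass. This is a sketch only: the class name SAEConfig and every field name are hypothetical, chosen to mirror the table, and the batch size is taken from the training-step description.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SAEConfig:                      # hypothetical name, mirrors the table above
    source_model: str = "google/gemma-2-2b"
    hook_layer: int = 12              # residual stream output of this layer
    d_model: int = 2304
    expansion: int = 8                # n_features = expansion * d_model
    k: int = 32                       # TopK
    lr: float = 3e-4
    weight_decay: float = 0.0         # "no weight decay"
    batch_size: int = 4096
    n_activations: int = 500_000      # size of the activation dataset

    @property
    def n_features(self) -> int:
        return self.expansion * self.d_model

cfg = SAEConfig()
```

Deriving n_features from the expansion factor keeps the 8 × d_model relationship explicit instead of hard-coding 18,432.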

Parameter counts

parameter   shape            count
W_enc       (18,432, 2304)   42.5 M
b_enc       (18,432,)        18.4 k
W_dec       (18,432, 2304)   42.5 M
b_dec       (2304,)          2.3 k
Total                        85.0 M

For comparison, the source model (Gemma 2 2B) has ~2.61 B parameters. The SAE is ~3% the size of the model it's analyzing — but it only has to fit one layer.
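
The counts and the size ratio follow directly from the shapes. A short check, using the approximate 2.61 B figure quoted above for the source model:

```python
d_model = 2304
n_features = 8 * d_model              # 18,432

counts = {
    "W_enc": n_features * d_model,
    "b_enc": n_features,
    "W_dec": n_features * d_model,
    "b_dec": d_model,
}
total = sum(counts.values())

gemma_params = 2.61e9                 # approximate, as quoted in the text
fraction = total / gemma_params

for name, n in counts.items():
    print(f"{name:6s} {n:>12,d}")
print(f"total  {total:>12,d}  ({100 * fraction:.1f}% of the source model)")
```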