TopK Transcoder — Gemma 2 2B, layer 12 MLP

An architecture sketch of how a transcoder would fit into our existing pipeline. We have not trained one yet — this page is documentation, parallel to sae_architecture.py. The transcoder occupies the same parametric shape as our SAE; what changes is what its decoder is trained to predict.


Forward pass: residual in, MLP output predicted

x — pre-MLP residual stream
[B, 2304] fp32
one vector per token, taken after Gemma 2's pre_feedforward_layernorm (an RMSNorm) and before the MLP, i.e. exactly the input the real MLP receives
path splits two ways
PATH A: real MLP (ground truth)
x → up + gate (9216) → GeGLU → down → mlp_out
the MLP we're trying to approximate. The shape is [B, 2304] going in and coming out, but the work happens in a 9216-dim hidden space whose neurons are hard to interpret one at a time.
PATH B: transcoder (our approximation)
reads the same x, produces a prediction of mlp_out via a sparse intermediate code. Everything below this row is path B.
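center the input (pre-bias trick, adapted from the SAE)
x − μ_x
[B, 2304]
In the SAE, subtracting b_dec works because input and reconstruction share one space. Here they don't: b_dec lives in MLP-output space, so subtracting it from x would mix spaces. One option is to center x by its own dataset mean μ_x, stored as a fixed buffer so the parameter count is unchanged (a separate learned input bias would also work). b_dec is still initialized to the dataset mean of MLP outputs, so the decoder starts from the average MLP output.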
@ W_encᵀ + b_enc (encoder)
z_pre
[B, 18,432]
W_enc shape = (18,432, 2304); rows are feature detectors reading the pre-MLP residual
ReLU
z_relu
[B, 18,432]
non-negative activations
TopK along feature dim (k = 32)
z — sparse code
[B, 18,432], L0 ≤ 32
keep the 32 largest per row, zero the rest (L0 is exactly 32 unless fewer than 32 entries survive the ReLU). These k features are the transcoder's claim about what the MLP is computing on this token.
@ W_dec (decoder)
z @ W_dec
[B, 2304]
W_dec shape = (18,432, 2304); row i is the direction in MLP-output space contributed by feature i. Decoder output lives in the MLP's output space, not the encoder's input space.
+ b_dec
mlp_out_hat — predicted MLP output
[B, 2304] fp32
loss = ((mlp_out − mlp_out_hat)²).mean(). Compared against PATH A's actual mlp_out.
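A minimal NumPy sketch of both paths at toy sizes, so the shapes and the TopK step can be checked by running it. It is not the code in sae_architecture.py: the function names are hypothetical, the weights are random, and it assumes Gemma 2's MLP uses a tanh-approximate GELU. It centers x by its own mean (mu_x) rather than by b_dec, because b_dec lives in MLP-output space.

```python
import numpy as np


def gelu_tanh(x):
    # Tanh approximation of GELU (assumed to match Gemma 2's MLP activation).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def reference_mlp(x, W_gate, W_up, W_down):
    # PATH A: GeGLU MLP, down(gelu(gate(x)) * up(x)).
    # Weights are stored (out_features, in_features), like nn.Linear.
    return (gelu_tanh(x @ W_gate.T) * (x @ W_up.T)) @ W_down.T


def topk_mask(z, k):
    # Keep the k largest entries in each row and zero the rest.
    idx = np.argpartition(z, -k, axis=1)[:, -k:]
    out = np.zeros_like(z)
    np.put_along_axis(out, idx, np.take_along_axis(z, idx, axis=1), axis=1)
    return out


def transcoder_forward(x, mu_x, W_enc, b_enc, W_dec, b_dec, k):
    # PATH B: center -> encode -> ReLU -> TopK -> decode into MLP-output space.
    z_pre = (x - mu_x) @ W_enc.T + b_enc   # [B, n_features]
    z_relu = np.maximum(z_pre, 0.0)        # non-negative activations
    z = topk_mask(z_relu, k)               # sparse code, L0 <= k
    mlp_out_hat = z @ W_dec + b_dec        # [B, d_model]
    return mlp_out_hat, z


# Toy sizes; the real ones are d_model=2304, d_mlp=9216, n_features=18432, k=32.
rng = np.random.default_rng(0)
B, d_model, d_mlp, n_features, k = 4, 16, 64, 128, 8

x = rng.standard_normal((B, d_model))
W_gate = rng.standard_normal((d_mlp, d_model)) / np.sqrt(d_model)
W_up = rng.standard_normal((d_mlp, d_model)) / np.sqrt(d_model)
W_down = rng.standard_normal((d_model, d_mlp)) / np.sqrt(d_mlp)
mlp_out = reference_mlp(x, W_gate, W_up, W_down)   # ground-truth target

W_enc = rng.standard_normal((n_features, d_model)) / np.sqrt(d_model)
b_enc = np.zeros(n_features)
W_dec = rng.standard_normal((n_features, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)   # unit-norm decoder rows
b_dec = mlp_out.mean(axis=0)   # init to the mean MLP output
mu_x = x.mean(axis=0)          # input-space centering vector

mlp_out_hat, z = transcoder_forward(x, mu_x, W_enc, b_enc, W_dec, b_dec, k)
loss = float(((mlp_out - mlp_out_hat) ** 2).mean())
```

argpartition only guarantees which k entries are kept, not their order, which is all TopK needs here.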

SAE vs Transcoder — what actually differs

property | SAE (what we trained) | Transcoder (what we'd train)
trained against | x (residual stream reconstruction) | MLP(x) (function approximation)
encoder input | residual stream at site S | normalized residual stream entering the MLP
decoder output | same residual stream at site S | the MLP's output at the same layer
replaces a computation in the forward pass? | no; its reconstruction can be spliced into the stream, but every MLP still runs | yes; swap it in for the MLP at inference and, if the approximation is good, the model should still mostly work
feature semantics | “this concept is present in the stream here” | “this is something the MLP computes and writes to the stream”
good for | reading the residual stream at one point | tracing computation flow across layers
usable for circuit tracing? | indirectly; connecting layers needs extra attribution work | directly; features compose into a feature-to-feature graph
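Written out, the first row of the table is the change in the objective. With z the sparse TopK code produced by each model's encoder, decoding is identical and only the target differs:

```latex
\begin{aligned}
\hat{v} &= z\,W_{\mathrm{dec}} + b_{\mathrm{dec}} \\
\mathcal{L}_{\mathrm{SAE}} &= \operatorname{mean}\big[(x - \hat{v})^2\big] \\
\mathcal{L}_{\mathrm{TC}} &= \operatorname{mean}\big[(\mathrm{MLP}(x) - \hat{v})^2\big]
\end{aligned}
```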

Inference-time view: swap the MLP for the transcoder

x — pre-MLP residual
[B, 2304]
same as before
during inference, route through the transcoder instead of the MLP
Transcoder.forward(x)
center → encoder → ReLU → TopK → decoder. Produces mlp_out_hat in one shot.
use mlp_out_hat where mlp_out would have gone
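mlp_out_hat → post_feedforward_layernorm → + residual → resid_post
[B, 2304]
Gemma 2 normalizes the MLP output before the residual add, so that norm stays in place and only the MLP module is swapped. The rest of the network sees a residual stream that's approximately what the MLP would have produced, but now we know exactly which k features (plus b_dec) make up the MLP's contribution and how much each adds.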
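Because Gemma 2's post_feedforward_layernorm is an RMSNorm, it rescales each token by one scalar (times a fixed per-dimension weight), so the per-feature decomposition survives the norm once that token's scale is frozen. The NumPy sketch below checks this on a single token; the weights are random stand-ins, the eps value is illustrative, and the Gemma-style (1 + weight) scaling is how Gemma's RMSNorm applies its weight.

```python
import numpy as np


def gemma_rmsnorm(v, weight, eps=1e-6):
    # Gemma-style RMSNorm: v / rms(v), then scale by (1 + weight).
    # Also returns the per-token scale r so contributions can reuse it.
    r = 1.0 / np.sqrt((v ** 2).mean(axis=-1, keepdims=True) + eps)
    return v * r * (1.0 + weight), r


rng = np.random.default_rng(1)
d_model, n_features, k = 16, 128, 8

W_dec = rng.standard_normal((n_features, d_model))
W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)
b_dec = 0.1 * rng.standard_normal(d_model)
post_ln_weight = 0.1 * rng.standard_normal(d_model)   # stand-in for post_feedforward_layernorm.weight
resid_mid = rng.standard_normal(d_model)              # residual stream before the MLP's add

# One token's sparse code: k active features and their activations.
active = rng.choice(n_features, size=k, replace=False)
acts = rng.uniform(0.5, 2.0, size=k)

# Swap: mlp_out_hat replaces mlp_out, then goes through the post-norm and the residual add.
mlp_out_hat = b_dec + acts @ W_dec[active]
normed, r = gemma_rmsnorm(mlp_out_hat, post_ln_weight)
resid_post = resid_mid + normed

# With this token's scale r frozen, the normed output splits exactly into
# a bias term plus one term per active feature.
feature_contrib = acts[:, None] * W_dec[active] * r * (1.0 + post_ln_weight)   # [k, d_model]
bias_contrib = b_dec * r * (1.0 + post_ln_weight)
```

The scale r depends on the whole vector, so these contributions are exact for this token but not independent of which other features fired.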

One training step

sample batch of (x, mlp_out) pairs
[B, 2304], [B, 2304]
we'd need a harvest pass that saves both the pre-MLP input and the MLP's output at layer 12, twice the storage of the SAE harvest.
forward pass → (mlp_out_hat, z)
see the top-level flow above
loss = mean((mlp_out − mlp_out_hat)²)
reconstruction target is the MLP's output, not the residual stream. Together with reading the pre-MLP input, this target swap is the main structural difference from SAE training.
loss.backward()
remove radial component of W_dec.grad → opt.step → renormalize
exactly the same constraint machinery as the SAE — unit-norm decoder rows.
log: loss, NMSE-vs-mlp_out, L0, dead-fraction
NMSE is now Σ‖mlp_out − mlp_out_hat‖² / Σ‖mlp_out − mlp_out_mean‖², and explained variance = 1 − NMSE. An explained variance of 1.0 would mean we perfectly recovered the MLP's function, on the harvested data, from k features per token.
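A sketch of those four logged quantities in NumPy. The function name and the dict keys are hypothetical, the batch mean stands in for the dataset mean of mlp_out, and "dead" is assumed to mean "never fired during the current logging window", tracked with a boolean array the caller resets per window.

```python
import numpy as np


def training_metrics(mlp_out, mlp_out_hat, z, fired_ever):
    # fired_ever: bool array [n_features], updated in place across the logging window.
    err = ((mlp_out - mlp_out_hat) ** 2).sum()
    # Batch mean used as a stand-in for the dataset mean of mlp_out.
    var = ((mlp_out - mlp_out.mean(axis=0)) ** 2).sum()
    nmse = float(err / var)
    fired_ever |= (z > 0).any(axis=0)
    return {
        "loss": float(((mlp_out - mlp_out_hat) ** 2).mean()),
        "nmse": nmse,
        "explained_variance": 1.0 - nmse,
        "l0": float((z > 0).sum(axis=1).mean()),
        "dead_fraction": float(1.0 - fired_ever.mean()),
    }


mlp_out = np.array([[1.0, 0.0], [-1.0, 0.0]])
z = np.array([[0.0, 2.0, 0.0], [0.0, 1.0, 0.0]])

# A perfect prediction gives NMSE 0; predicting the mean output gives NMSE 1.
perfect = training_metrics(mlp_out, mlp_out.copy(), z, np.zeros(3, dtype=bool))
mean_only = training_metrics(mlp_out, np.tile(mlp_out.mean(axis=0), (2, 1)), z, np.zeros(3, dtype=bool))
```

The two calls bracket the scale: predicting the average MLP output (what b_dec starts at) scores an explained variance of 0.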

Conceptual notes

Why a transcoder is "the MLP, decomposed"

An SAE asks: what concepts are present in the residual stream at point S? It learns a sparse code for whatever vector lives there. It doesn't know or care where that vector came from.

A transcoder asks a sharper question: what is this specific MLP computing? Its input is the MLP's input, its output is the MLP's output, and its training loss is the mismatch between them. The k features that fire on any given token are the transcoder's claim about which k operations the MLP performed for that token.

Why this matters for circuit tracing

If a trained transcoder approximates the MLP well, you can swap it in at inference and the downstream network should barely notice. Every contribution the "MLP" makes to the residual stream is then a sum of at most k = 32 named feature directions plus b_dec (up to the post-MLP norm's per-token rescaling). That gives circuit tracing a place to stand: you can attribute a downstream feature's activation to specific upstream features by taking dot products between upstream decoder rows and downstream encoder rows, because under the swap the decoder weights are that layer's computation.
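A rough NumPy sketch of that direct feature-to-feature score, under loudly simplified assumptions: one token, only the direct residual path from an upstream transcoder to a downstream one, and no attention, no RMSNorms, no biases, and no downstream ReLU/TopK gating. The function and variable names are hypothetical.

```python
import numpy as np


def direct_attribution(z_up, W_dec_up, W_enc_down):
    # score[i, j] = z_up[i] * (W_dec_up[i] . W_enc_down[j]):
    # how much upstream feature i, at its current activation, pushes
    # downstream feature j's pre-activation along the direct path.
    return z_up[:, None] * (W_dec_up @ W_enc_down.T)


rng = np.random.default_rng(2)
d_model, n_up, n_down = 16, 64, 48
W_dec_up = rng.standard_normal((n_up, d_model))
W_enc_down = rng.standard_normal((n_down, d_model))

z_up = np.zeros(n_up)
z_up[[3, 10, 42]] = [1.5, 0.7, 2.0]   # a sparse upstream code

scores = direct_attribution(z_up, W_dec_up, W_enc_down)   # [n_up, n_down]
# Upstream features ranked by how strongly they drive downstream feature 0.
top_sources_for_0 = np.argsort(-np.abs(scores[:, 0]))[:3]
```

Summing a column recovers the upstream layer's whole decoded contribution to that downstream pre-activation, which is what makes the per-feature split meaningful.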

With an SAE you can read what's in the stream but not who put it there. A transcoder-replaced layer makes the "who put it there" question answerable, for that layer's MLP contribution, by reading off the active features and their W_dec rows.

What stays the same as the SAE

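Almost everything else. Same TopK sparsity. Same unit-norm decoder constraint. Same AdamW. Same dead-feature concerns. The pre-bias trick carries over with one adjustment: b_dec lives in MLP-output space, so it can't be subtracted from x, and the input needs its own centering vector. From a training-code perspective, a transcoder is an SAE whose input and reconstruction target are different tensors.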
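The unit-norm decoder constraint can be sketched in NumPy as two helpers: project out each row's radial gradient component before the step, then renormalize rows after it. The helper names are hypothetical, and plain SGD stands in for AdamW so the example stays self-contained.

```python
import numpy as np


def remove_radial_component(W_dec, grad):
    # For unit-norm rows w, drop the part of each row's gradient parallel to w,
    # so a step moves rows along the sphere rather than changing their length.
    radial = (grad * W_dec).sum(axis=1, keepdims=True)
    return grad - radial * W_dec


def renormalize_rows(W_dec, eps=1e-8):
    # Project rows back onto the unit sphere after the optimizer step.
    return W_dec / (np.linalg.norm(W_dec, axis=1, keepdims=True) + eps)


rng = np.random.default_rng(3)
W_dec0 = renormalize_rows(rng.standard_normal((128, 16)))
grad = rng.standard_normal((128, 16))

grad_t = remove_radial_component(W_dec0, grad)
# Plain SGD stands in for AdamW; the projection and renormalization are unchanged.
W_dec1 = renormalize_rows(W_dec0 - 0.01 * grad_t)
```

The renormalization is still needed: even a tangent step leaves the sphere slightly, and AdamW's per-coordinate scaling does not preserve tangency.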

What changes in our pipeline if we want to train one

The activation harvest. Right now we save the layer-12 residual stream only (one vector per token). For a transcoder we'd need two tensors per token: the pre-MLP residual (input) and the MLP's output (target). That's 2× the disk footprint and a slightly more invasive hook setup, but otherwise the same harvest script.
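The 2× figure is easy to make concrete. This sketch assumes vectors are stored uncompressed in fp32 (as in the [B, 2304] fp32 shapes above); the function name is hypothetical, and storing bf16 would halve every number.

```python
def harvest_bytes(n_tokens, d_model=2304, bytes_per_value=4, tensors_per_token=2):
    # Transcoder harvest stores (pre-MLP input, MLP output) per token: 2 tensors.
    # The SAE harvest stores only the residual stream: tensors_per_token=1.
    return n_tokens * d_model * bytes_per_value * tensors_per_token


per_token_transcoder = harvest_bytes(1)                  # 18,432 bytes
per_token_sae = harvest_bytes(1, tensors_per_token=1)    # 9,216 bytes
per_million_gb = harvest_bytes(1_000_000) / 1e9          # about 18.4 GB
```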

Where transcoders sit in the lineage

Transcoders aren't a replacement for SAEs — they're a complementary tool for a different question. The field also has crosscoders (span multiple layers or models at once) and earlier dictionary-learning approaches. SAE / transcoder / crosscoder are three points on the same design space: sparse over-complete codes pointed at different signals in the network.

Config

field | value
source model | google/gemma-2-2b
site | model.model.layers[12].mlp (drop-in replacement for the MLP module)
encoder input | pre-MLP residual stream after pre_feedforward_layernorm (RMSNorm), i.e. the MLP's own input
decoder output | predicts the MLP's output (down_proj output, before post_feedforward_layernorm and the residual add)
d_model | 2304
MLP hidden (GeGLU) | 9216
n_features | 18,432 (= 8 × d_model)
k (TopK) | 32
loss target | the MLP's actual output, not the residual stream
loss | MSE: mean of (mlp_out − mlp_out_hat)² over batch and d_model
decoder norm constraint | 1.0 per row (same as SAE)
training data | would need a fresh harvest of (pre-MLP, post-MLP) pairs at layer 12
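If this config ends up in code, a frozen dataclass keeps the derived values tied to the base fields. TranscoderConfig and its field names are hypothetical, not an existing class in our pipeline.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TranscoderConfig:
    # Hypothetical container mirroring the config table above.
    source_model: str = "google/gemma-2-2b"
    layer: int = 12
    d_model: int = 2304
    d_mlp: int = 9216          # GeGLU hidden width
    expansion: int = 8         # n_features = expansion * d_model
    k: int = 32                # TopK
    decoder_row_norm: float = 1.0

    @property
    def n_features(self) -> int:
        return self.expansion * self.d_model

    @property
    def site(self) -> str:
        # Module path of the MLP being replaced.
        return f"model.model.layers[{self.layer}].mlp"


cfg = TranscoderConfig()
```

Deriving n_features and site instead of storing them means changing the layer or expansion factor can't leave them stale.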

Parameter counts

parameter | shape | count
W_enc | (18,432, 2304) | 42.5 M
b_enc | (18,432,) | 18.4 k
W_dec | (18,432, 2304) | 42.5 M
b_dec | (2304,) | 2.3 k
Total | | 85.0 M

Identical to the SAE in parameter count. The architecture is the same shape — encoder + sparse latent + decoder. What changes is what we point the decoder at during training.
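The table's counts can be reproduced with a few lines of arithmetic; the helper name is hypothetical.

```python
def transcoder_param_count(d_model: int, n_features: int) -> dict:
    # Same four tensors as the SAE: encoder weight/bias, decoder weight/bias.
    return {
        "W_enc": n_features * d_model,
        "b_enc": n_features,
        "W_dec": n_features * d_model,
        "b_dec": d_model,
    }


counts = transcoder_param_count(d_model=2304, n_features=18_432)
total = sum(counts.values())
print(f"{total:,}")   # 84,955,392
```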