[Proposal] Allow tied embeddings #671
Labels: complexity-moderate (moderately complicated issues for people with intermediate experience with the code), enhancement (new feature or request)
Proposal
TransformerLens assumes all models have untied embeddings (i.e. W_U != W_E.T). This is a good assumption in general, and it needs to hold if LayerNorm is folded, but it costs extra memory.
This is particularly bad for Gemma models, which have tied embeddings and a very large vocab size: W_E is about 25% of Gemma 2 2.6B's parameters and about 10% of Gemma 2 9B's.

I think it would be great to load the tied models with tied embeddings by default (so W_U.data = W_E.data.T), and to provide a helper function that clones the matrix and unties them if need be. This would involve adding a tied_embeddings field to the Config, which defaults to False, is set to True for select models like GPT-2, Gemma, and Gemma 2, and is set back to False if LayerNorm folding is run.
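To make the mechanics concrete, here is a minimal PyTorch sketch of what tying and untying could look like. The toy module, the tied_embeddings flag, and the untie_embeddings helper are illustrative names assumed for this example, not TransformerLens's actual config fields or API.

```python
# Minimal sketch of the proposed mechanics, not TransformerLens's actual API:
# TinyTiedModel, tied_embeddings, and untie_embeddings() are assumed names.
import torch
import torch.nn as nn


class TinyTiedModel(nn.Module):
    def __init__(self, d_vocab: int, d_model: int, tied_embeddings: bool = False):
        super().__init__()
        self.tied_embeddings = tied_embeddings
        self.W_E = nn.Parameter(torch.randn(d_vocab, d_model))
        if tied_embeddings:
            # Share storage with W_E (W_U.data = W_E.data.T), so no extra
            # d_vocab x d_model matrix is held for the unembedding.
            self.W_U = nn.Parameter(torch.empty(0))
            self.W_U.data = self.W_E.data.T
        else:
            self.W_U = nn.Parameter(torch.randn(d_model, d_vocab))

    def untie_embeddings(self) -> None:
        """Clone the shared matrix so W_U is independent of W_E,
        e.g. before folding LayerNorm into the unembedding."""
        if self.tied_embeddings:
            self.W_U.data = self.W_E.data.T.clone()
            self.tied_embeddings = False


# Toy dimensions; for Gemma 2 2.6B the real embedding is roughly 256k x 2304.
model = TinyTiedModel(d_vocab=1000, d_model=64, tied_embeddings=True)
assert model.W_U.data_ptr() == model.W_E.data_ptr()  # same underlying storage
model.untie_embeddings()
assert model.W_U.data_ptr() != model.W_E.data_ptr()  # now a separate copy
```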
I'd love for people to be able to work with the Gemma 2 models with a bunch of SAEs in memory, so memory efficiency is important here (and folding LayerNorm isn't that important).