Visualize weight conversions #852

Open · wants to merge 4 commits into base: unified-conversions
12 changes: 6 additions & 6 deletions transformer_lens/factories/mlp_factory.py
@@ -2,17 +2,17 @@

Centralized location for creating any MLP needed within TransformerLens
"""
from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.components.mlps.gated_mlp import GatedMLP
from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
from transformer_lens.components.mlps.mlp import MLP
from transformer_lens.components.mlps.moe import MoE
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class MLPFactory:
@staticmethod
def create_mlp(cfg: HookedTransformerConfig) -> CanBeUsedAsMLP:
def create_mlp(cfg: HookedTransformerConfig):
from transformer_lens.components.mlps.gated_mlp import GatedMLP
from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
from transformer_lens.components.mlps.mlp import MLP
from transformer_lens.components.mlps.moe import MoE

if cfg.num_experts:
return MoE(cfg)
elif cfg.gated_mlp:
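The hunk above moves the concrete MLP imports from module level into `create_mlp` and drops the `CanBeUsedAsMLP` return annotation, deferring the component imports to call time (typically done to break an import cycle between the factory and the component modules). A minimal sketch of the resulting dispatch, with the branches after `gated_mlp` filled in as an assumption because the rest of the hunk is collapsed:

```python
# Minimal sketch of the factory implied by this hunk. The 4-bit branch and the
# use of cfg.load_in_4bit are assumptions; the tail of the diff is collapsed.
class MLPFactory:
    @staticmethod
    def create_mlp(cfg):
        # Deferred imports: the factory module no longer needs the component
        # modules at import time, which avoids circular imports.
        from transformer_lens.components.mlps.gated_mlp import GatedMLP
        from transformer_lens.components.mlps.gated_mlp_4bit import GatedMLP4Bit
        from transformer_lens.components.mlps.mlp import MLP
        from transformer_lens.components.mlps.moe import MoE

        if cfg.num_experts:
            return MoE(cfg)
        elif cfg.gated_mlp:
            return GatedMLP4Bit(cfg) if cfg.load_in_4bit else GatedMLP(cfg)
        return MLP(cfg)
```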
8 changes: 5 additions & 3 deletions transformer_lens/factories/weight_conversion_factory.py
@@ -1,12 +1,14 @@
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.weight_conversion.bert import BertWeightConversion
from transformer_lens.weight_conversion.bloom import BloomWeightConversion
from transformer_lens.weight_conversion.conversion_utils.architecture_conversion import (
ArchitectureConversion,
)
from transformer_lens.weight_conversion.bert import BertWeightConversion
from transformer_lens.weight_conversion.bloom import BloomWeightConversion
from transformer_lens.weight_conversion.gemma import GemmaWeightConversion
from transformer_lens.weight_conversion.gpt2 import GPT2WeightConversion
from transformer_lens.weight_conversion.gpt2_lm_head_custom import GPT2LMHeadCustomWeightConversion
from transformer_lens.weight_conversion.gpt2_lm_head_custom import (
GPT2LMHeadCustomWeightConversion,
)
from transformer_lens.weight_conversion.gptj import GPTJWeightConversion
from transformer_lens.weight_conversion.mistral import MistralWeightConversion
from transformer_lens.weight_conversion.mixtral import MixtralWeightConversion
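Only the import block of `weight_conversion_factory.py` is visible here; the hunk reorders it and wraps the long `GPT2LMHeadCustomWeightConversion` import. As a hedged sketch of how such a factory is typically used (the method name `select_conversion` and the dispatch on `cfg.original_architecture` are assumptions, since the factory body is not shown), it maps a config to the matching `ArchitectureConversion`:

```python
# Hedged sketch only: the method name and architecture strings are illustrative.
# The *WeightConversion classes are the ones imported at the top of this file.
class WeightConversionFactory:
    @staticmethod
    def select_conversion(cfg: HookedTransformerConfig) -> ArchitectureConversion:
        match cfg.original_architecture:
            case "BertForMaskedLM":
                return BertWeightConversion(cfg)
            case "BloomForCausalLM":
                return BloomWeightConversion(cfg)
            case "GPT2LMHeadModel":
                return GPT2WeightConversion(cfg)
            case _:
                raise NotImplementedError(
                    f"No weight conversion registered for {cfg.original_architecture}"
                )
```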
13 changes: 2 additions & 11 deletions transformer_lens/loading_from_pretrained.py
@@ -22,10 +22,6 @@
import transformer_lens.utils as utils
from transformer_lens.factories import WeightConversionFactory
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.weight_conversion import (
convert_mingpt_weights,
convert_neel_solu_old_weights,
)

OFFICIAL_MODEL_NAMES = [
"gpt2",
@@ -1799,7 +1795,7 @@ def load_hugging_face_model(
token=huggingface_token,
**kwargs,
)

return hf_model


@@ -1837,14 +1833,9 @@ def get_pretrained_state_dict(
f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
)
kwargs["trust_remote_code"] = True


hf_model = load_hugging_face_model(
official_model_name,
cfg=cfg,
hf_model=hf_model,
dtype=dtype,
**kwargs
official_model_name, cfg=cfg, hf_model=hf_model, dtype=dtype, **kwargs
)

for param in hf_model.parameters():
11 changes: 0 additions & 11 deletions transformer_lens/weight_conversion/__init__.py
@@ -1,19 +1,8 @@
from .neo import convert_neo_weights
from .gpt2 import convert_gpt2_weights
from .opt import convert_opt_weights
from .gptj import convert_gptj_weights
from .neox import convert_neox_weights
from .llama import convert_llama_weights
from .bert import convert_bert_weights
from .mistral import convert_mistral_weights
from .mixtral import MixtralWeightConversion
from .bloom import convert_bloom_weights
from .coder import convert_coder_weights
from .qwen import convert_qwen_weights
from .qwen2 import convert_qwen2_weights
from .phi import convert_phi_weights
from .phi3 import convert_phi3_weights
from .gemma import convert_gemma_weights
from .mingpt import convert_mingpt_weights
from .nanogpt import convert_nanogpt_weights
from .t5 import convert_t5_weights
97 changes: 66 additions & 31 deletions transformer_lens/weight_conversion/bert.py
@@ -5,36 +5,71 @@
WeightConversionSet,
)


class BertWeightConversion(ArchitectureConversion):
def __init__(self, cfg: HookedTransformerConfig) -> None:
super().__init__({
"embed.embed.W_E": "bert.embeddings.word_embeddings.weight",
"embed.pos_embed.W_pos": "bert.embeddings.position_embeddings.weight",
"embed.token_type_embed.W_token_type": "bert.embeddings.token_type_embeddings.weight",
"embed.ln.w": "bert.embeddings.LayerNorm.weight",
"embed.ln.b": "bert.embeddings.LayerNorm.bias",
"mlm_head.W": "bert.cls.predictions.transform.dense.weight",
"mlm_head.b": "bert.cls.predictions.transform.dense.bias",
"mlm_head.ln.w": "bert.cls.predictions.transform.LayerNorm.weight",
"mlm_head.ln.b": "bert.cls.predictions.transform.LayerNorm.bias",
"mlm_head.W_U": "bert.embeddings.word_embeddings.weight.T",
"mlm_head.b_U": "bert.cls.predictions.bias",
"blocks": ("bert.encoder.layer", WeightConversionSet({
"attn.W_Q": ("attention.self.query.weight", RearrangeWeightConversion("(i h) m -> i m h", i=cfg.n_heads)),
"attn.b_Q": ("attention.self.query.bias", RearrangeWeightConversion("(i h) -> i h", i=cfg.n_heads)),
"attn.W_K": ("attention.self.key.weight", RearrangeWeightConversion("(i h) m -> i m h", i=cfg.n_heads)),
"attn.b_K": ("attention.self.key.bias", RearrangeWeightConversion("(i h) -> i h", i=cfg.n_heads)),
"attn.W_V": ("attention.self.value.weight", RearrangeWeightConversion("(i h) m -> i m h", i=cfg.n_heads)),
"attn.b_V": ("attention.self.value.bias", RearrangeWeightConversion("(i h) -> i h", i=cfg.n_heads)),
"attn.W_O": ("attention.self.dense.weight", RearrangeWeightConversion("m (i h) -> i h m", i=cfg.n_heads)),
"attn.b_O": "attention.output.dense.bias",
"ln1.w": "attention.output.LayerNorm.weight",
"ln1.b": "attention.output.LayerNorm.bias",
"mlp.W_in": ("intermediate.dense.weight", RearrangeWeightConversion("mlp model -> model mlp")),
"mlp.b_in": "intermediate.dense.bias",
"mlp.W_out": ("output.dense.weight", RearrangeWeightConversion("model mlp -> mlp model")),
"mlp.b_out": "output.dense.bias",
"ln2.w": "output.LayerNorm.weight",
"ln2.b": "output.LayerNorm.bias",
}))
})
super().__init__(
{
"embed.embed.W_E": "bert.embeddings.word_embeddings.weight",
"embed.pos_embed.W_pos": "bert.embeddings.position_embeddings.weight",
"embed.token_type_embed.W_token_type": "bert.embeddings.token_type_embeddings.weight",
"embed.ln.w": "bert.embeddings.LayerNorm.weight",
"embed.ln.b": "bert.embeddings.LayerNorm.bias",
"mlm_head.W": "cls.predictions.transform.dense.weight",
"mlm_head.b": "cls.predictions.transform.dense.bias",
"mlm_head.ln.w": "cls.predictions.transform.LayerNorm.weight",
"mlm_head.ln.b": "cls.predictions.transform.LayerNorm.bias",
"mlm_head.W_U": "bert.embeddings.word_embeddings.weight.T",
"mlm_head.b_U": "cls.predictions.bias",
"blocks": (
"bert.encoder.layer",
WeightConversionSet(
{
"attn.W_Q": (
"attention.self.query.weight",
RearrangeWeightConversion("(i h) m -> i m h", i=cfg.n_heads),
),
"attn.b_Q": (
"attention.self.query.bias",
RearrangeWeightConversion("(i h) -> i h", i=cfg.n_heads),
),
"attn.W_K": (
"attention.self.key.weight",
RearrangeWeightConversion("(i h) m -> i m h", i=cfg.n_heads),
),
"attn.b_K": (
"attention.self.key.bias",
RearrangeWeightConversion("(i h) -> i h", i=cfg.n_heads),
),
"attn.W_V": (
"attention.self.value.weight",
RearrangeWeightConversion("(i h) m -> i m h", i=cfg.n_heads),
),
"attn.b_V": (
"attention.self.value.bias",
RearrangeWeightConversion("(i h) -> i h", i=cfg.n_heads),
),
"attn.W_O": (
"attention.output.dense.weight",
RearrangeWeightConversion("m (i h) -> i h m", i=cfg.n_heads),
),
"attn.b_O": "attention.output.dense.bias",
"ln1.w": "attention.output.LayerNorm.weight",
"ln1.b": "attention.output.LayerNorm.bias",
"mlp.W_in": (
"intermediate.dense.weight",
RearrangeWeightConversion("mlp model -> model mlp"),
),
"mlp.b_in": "intermediate.dense.bias",
"mlp.W_out": (
"output.dense.weight",
RearrangeWeightConversion("model mlp -> mlp model"),
),
"mlp.b_out": "output.dense.bias",
"ln2.w": "output.LayerNorm.weight",
"ln2.b": "output.LayerNorm.bias",
}
),
),
}
)
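The strings handed to `RearrangeWeightConversion` are einops patterns applied to the corresponding Hugging Face tensor. BERT stores each attention projection as a single `(n_heads * d_head, d_model)` matrix, so `"(i h) m -> i m h"` splits out the head dimension and moves it last, giving TransformerLens's `[n_heads, d_model, d_head]` layout for `W_Q`. A standalone illustration (bert-base shapes used purely for concreteness):

```python
import torch
from einops import rearrange

n_heads, d_head, d_model = 12, 64, 768  # bert-base sizes, for illustration only

# Hugging Face keeps the query projection as one (n_heads * d_head, d_model) matrix.
hf_W_Q = torch.randn(n_heads * d_head, d_model)

# Same operation as RearrangeWeightConversion("(i h) m -> i m h", i=n_heads):
tl_W_Q = rearrange(hf_W_Q, "(i h) m -> i m h", i=n_heads)
print(tl_W_Q.shape)  # torch.Size([12, 768, 64]) == [n_heads, d_model, d_head]
```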
132 changes: 75 additions & 57 deletions transformer_lens/weight_conversion/bloom.py
@@ -1,11 +1,12 @@
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.weight_conversion.conversion_utils import ArchitectureConversion
from transformer_lens.weight_conversion.conversion_utils.conversion_steps import (
CallableWeightConversion,
RearrangeWeightConversion,
WeightConversionSet,
CallableWeightConversion,
)


class BloomWeightConversion(ArchitectureConversion):
def __init__(self, cfg: HookedTransformerConfig) -> None:
super().__init__(
@@ -15,62 +16,79 @@ def __init__(self, cfg: HookedTransformerConfig) -> None:
"embed.ln.b": "transformer.word_embeddings_layernorm.bias",
"unembed.W_U": "lm_head.weight.T",
"ln_final.b": "transformer.ln_f.bias",
"blocks": ("transformer.h", WeightConversionSet({
"ln1.w": "input_layernorm.weight",
"ln1.b": "input_layernorm.bias",
"attn.W_Q": (
"self_attention.query_key_value.weight.T",
RearrangeWeightConversion(
"m n h ->n m h",
input_filter=lambda weight: weight.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)[..., 0, :],
n=cfg.n_heads,
)
),
"attn.W_K": (
"self_attention.query_key_value.weight.T",
RearrangeWeightConversion(
"m n h ->n m h",
input_filter=lambda weight: weight.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)[..., 1, :],
n=cfg.n_heads,
)
),
"attn.W_V": (
"self_attention.query_key_value.weight.T",
RearrangeWeightConversion(
"m n h ->n m h",
input_filter=lambda weight: weight.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)[..., 2, :],
n=cfg.n_heads,
)
),
"attn.b_Q": (
"self_attention.query_key_value.bias",
CallableWeightConversion(
convert_callable=lambda weight: weight.reshape(cfg.n_heads, 3, cfg.d_head)[..., 0, :],
)
),
"attn.b_K": (
"self_attention.query_key_value.bias",
CallableWeightConversion(
convert_callable=lambda weight: weight.reshape(cfg.n_heads, 3, cfg.d_head)[..., 1, :],
)
),
"attn.b_V": (
"self_attention.query_key_value.bias",
CallableWeightConversion(
convert_callable=lambda weight: weight.reshape(cfg.n_heads, 3, cfg.d_head)[..., 2, :],
)
),
"attn.W_O": (
"self_attention.dense.weight.T",
RearrangeWeightConversion("(n h) m->n h m", n=cfg.n_heads)
"blocks": (
"transformer.h",
WeightConversionSet(
{
"ln1.w": "input_layernorm.weight",
"ln1.b": "input_layernorm.bias",
"attn.W_Q": (
"self_attention.query_key_value.weight.T",
RearrangeWeightConversion(
"m n h -> n m h",
input_filter=lambda weight: weight.reshape(
cfg.d_model, cfg.n_heads, 3, cfg.d_head
)[..., 0, :],
n=cfg.n_heads,
),
),
"attn.W_K": (
"self_attention.query_key_value.weight.T",
RearrangeWeightConversion(
"m n h -> n m h",
input_filter=lambda weight: weight.reshape(
cfg.d_model, cfg.n_heads, 3, cfg.d_head
)[..., 1, :],
n=cfg.n_heads,
),
),
"attn.W_V": (
"self_attention.query_key_value.weight.T",
RearrangeWeightConversion(
"m n h -> n m h",
input_filter=lambda weight: weight.reshape(
cfg.d_model, cfg.n_heads, 3, cfg.d_head
)[..., 2, :],
n=cfg.n_heads,
),
),
"attn.b_Q": (
"self_attention.query_key_value.bias",
CallableWeightConversion(
convert_callable=lambda weight: weight.reshape(
cfg.n_heads, 3, cfg.d_head
)[..., 0, :],
),
),
"attn.b_K": (
"self_attention.query_key_value.bias",
CallableWeightConversion(
convert_callable=lambda weight: weight.reshape(
cfg.n_heads, 3, cfg.d_head
)[..., 1, :],
),
),
"attn.b_V": (
"self_attention.query_key_value.bias",
CallableWeightConversion(
convert_callable=lambda weight: weight.reshape(
cfg.n_heads, 3, cfg.d_head
)[..., 2, :],
),
),
"attn.W_O": (
"self_attention.dense.weight.T",
RearrangeWeightConversion("(n h) m -> n h m", n=cfg.n_heads),
),
"attn.b_O": "self_attention.dense.bias",
"ln2.w": "post_attention_layernorm.weight",
"ln2.b": "post_attention_layernorm.bias",
"mlp.W_in": "mlp.dense_h_to_4h.weight.T",
"mlp.b_in": "mlp.dense_h_to_4h.bias",
"mlp.W_out": "mlp.dense_4h_to_h.weight.T",
"mlp.b_out": "mlp.dense_4h_to_h.bias",
}
),
"attn.b_O": "self_attention.dense.bias",
"ln2.w": "post_attention_layernorm.weight",
"ln2.b": "post_attention_layernorm.bias",
"mlp.W_in": "mlp.dense_h_to_4h.weight.T",
"mlp.b_in": "mlp.dense_h_to_4h.bias",
"mlp.W_out": "mlp.dense_4h_to_h.weight.T",
"mlp.b_out": "mlp.dense_4h_to_h.bias",
}))
),
}
)
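BLOOM fuses Q, K and V into one `query_key_value` projection, so each entry above first slices the fused tensor with its `input_filter` lambda and then rearranges the result. A standalone sketch of what the `attn.W_Q` conversion does (bloom-560m sizes used for concreteness):

```python
import torch
from einops import rearrange

n_heads, d_head = 16, 64
d_model = n_heads * d_head  # 1024; bloom-560m sizes, for illustration only

# The conversion starts from "self_attention.query_key_value.weight.T",
# i.e. the transpose of the fused (3 * n_heads * d_head, d_model) projection.
fused_T = torch.randn(d_model, 3 * n_heads * d_head)

# input_filter: expose the interleaved (head, qkv, d_head) layout and keep the
# query slice (index 0; K and V use indices 1 and 2).
q_slice = fused_T.reshape(d_model, n_heads, 3, d_head)[..., 0, :]  # (d_model, n_heads, d_head)

# RearrangeWeightConversion("m n h -> n m h") then moves heads to the front.
W_Q = rearrange(q_slice, "m n h -> n m h")
print(W_Q.shape)  # torch.Size([16, 1024, 64]) == [n_heads, d_model, d_head]
```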
@@ -1 +1,2 @@
from .architecture_conversion import ArchitectureConversion
from .weight_conversion_utils import WeightConversionUtils
@@ -1,6 +1,8 @@
from torch import nn

from .conversion_steps.base_weight_conversion import FIELD_SET
from .conversion_steps.weight_conversion_set import WeightConversionSet
from .weight_conversion_utils import WeightConversionUtils


class ArchitectureConversion:
@@ -9,3 +11,6 @@ def __init__(self, fields: FIELD_SET) -> None:

def convert(self, remote_module: nn.Module):
return self.field_set.convert(input_value=remote_module)

def __repr__(self) -> str:
return WeightConversionUtils.create_conversion_string(self.field_set.weights)
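This `__repr__` is the visualization the PR title refers to: printing any `ArchitectureConversion` subclass renders its field mapping through `WeightConversionUtils.create_conversion_string` instead of running the conversion. A hedged usage sketch on this branch (the exact output format depends on `create_conversion_string`, which is not part of the visible hunks):

```python
# Hedged usage sketch: print a conversion to inspect its weight mapping.
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.weight_conversion.bert import BertWeightConversion

cfg = HookedTransformerConfig.from_dict(
    {"n_layers": 12, "d_model": 768, "n_ctx": 512, "d_head": 64, "act_fn": "gelu"}
)
print(BertWeightConversion(cfg))  # one entry per TransformerLens weight and its source
```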
@@ -28,3 +28,6 @@ def handle_conversion(self, input_value):
return input_value * self.value
case OperationTypes.DIVISION:
return input_value / self.value

def __repr__(self):
return f"Is the following arithmetic operation: {self.operation} and value: {self.value}"