
[Proposal] Compatibility for OLMo and OLMo2? #804

Open
1 task done
spaidataiga opened this issue Nov 28, 2024 · 2 comments
Labels
complexity-moderate: Moderately complicated issues for people who have intermediate experience with the code
model-request: Any issues related to requesting additional model support

Comments

@spaidataiga

spaidataiga commented Nov 28, 2024

Proposal

It would be nice to include OLMo (1B and 7B), along with their training checkpoints, as compatible models for HookedTransformer.

Motivation

OLMo-1B would be a great model for mechanistic interpretability work, especially because it is fully open: the training data, training process, and intermediate checkpoints are all released, which makes it possible to relate them to model behaviour. Its architecture should be fairly similar to models that are already supported. If it is already possible to get it running, I would really appreciate a pointer to relevant information; I have looked through the documentation myself in the meantime but have not found anything.

Pitch

Add OLMo-1B and OLMo-7B. Add OLMo-2-7B and OLMo-2-13B. Ideally, also expose their intermediate training checkpoints.
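Ideally, once support lands, something like the following sketch would just work (the model id here is the HuggingFace repo name; how intermediate checkpoints would be exposed, e.g. via the checkpoint_index / checkpoint_value arguments used for other checkpointed models, is an open question):

from transformer_lens import HookedTransformer

# Hypothetical usage once support lands; the model id is the HF repo name.
model = HookedTransformer.from_pretrained("allenai/OLMo-2-1124-7B")
logits = model("OLMo is a fully open language model")
# Intermediate checkpoints might be selected the way other checkpointed models
# (e.g. Pythia) are, via checkpoint_index / checkpoint_value, but that is TBD.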

Checklist

  • I have checked that there is no similar issue in the repo (required)
bryce13950 added the complexity-moderate and model-request labels Dec 9, 2024
@Neelectric

Neelectric commented Dec 13, 2024

I would just like to express my enthusiastic endorsement of this proposal. I tried my hand at the implementation a little bit and thought I would share some of what that revealed. It seems to me that OLMo-1 and OLMo-2 follow Llama-2 quite closely. For example, in convert_hf_model_config() inside loading_from_pretrained.py, something similar to

if official_model_name.startswith(
    ("olmo2-7b", "allenai/OLMo-2-1124-7B")
):  # OLMo-2 closely follows the Llama-2 architecture
    cfg_dict = {
        "d_model": 4096,
        "d_head": 4096 // 32,
        "n_heads": 32,
        "d_mlp": 11008,
        "n_layers": 32,
        "n_ctx": 4096,
        "eps": 1e-6,
        "d_vocab": 100352,
        "act_fn": "silu",
        "normalization_type": "RMS",
        "positional_embedding_type": "rotary",
        "rotary_adjacent_pairs": False,
        "rotary_dim": 4096 // 32,
        "final_rms": True,
        "gated_mlp": True,
    }

might make sense? Then we would probably need a new pretrained/weight_conversions/olmo2.py file. A nuance here is that, when loaded with the newest version of HuggingFace transformers (at the time of writing), an Olmo2ForCausalLM object looks like

Olmo2ForCausalLM(
  (model): Olmo2Model(
    (embed_tokens): Embedding(100352, 4096, padding_idx=100277)
    (layers): ModuleList(
      (0-31): 32 x Olmo2DecoderLayer(
        (self_attn): Olmo2SdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): Olmo2RotaryEmbedding()
          (q_norm): Olmo2RMSNorm((4096,), eps=1e-06)
          (k_norm): Olmo2RMSNorm((4096,), eps=1e-06)
        )
        (mlp): Olmo2MLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (post_attention_layernorm): Olmo2RMSNorm((4096,), eps=1e-06)
        (post_feedforward_layernorm): Olmo2RMSNorm((4096,), eps=1e-06)
      )
    )
    (norm): Olmo2RMSNorm((4096,), eps=1e-06)
  )
  (lm_head): Linear(in_features=4096, out_features=100352, bias=False)
)
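(For reference, the structure above comes from simply loading and printing the model, and the config numbers proposed earlier can be cross-checked the same way. A minimal sketch, assuming the usual Llama-style config attribute names and a transformers version recent enough to include Olmo2ForCausalLM:)

from transformers import AutoConfig, AutoModelForCausalLM

# Cross-check the proposed cfg_dict values against the HF config.
hf_cfg = AutoConfig.from_pretrained("allenai/OLMo-2-1124-7B")
print(hf_cfg.hidden_size, hf_cfg.num_attention_heads, hf_cfg.num_hidden_layers,
      hf_cfg.intermediate_size, hf_cfg.rms_norm_eps, hf_cfg.vocab_size)

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")
print(model)  # prints the module tree shown above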

Whereas a LlamaForCausalLM object from 'meta-llama/Llama-2-7b-hf' looks like

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
 

So, for example, OLMo-2 has (post_attention_layernorm) and (post_feedforward_layernorm) at every layer, as opposed to (input_layernorm) and (post_attention_layernorm) for Llama-2. It also has additional (rotary_emb), (q_norm), and (k_norm) modules in every self_attn block, which Llama-2 does not, while it lacks the model-wide (rotary_emb) that Llama-2 has. There is also the vocabulary size of 100352 in OLMo-2 vs. 32000 in Llama-2. Finally, Olmo2RMSNorm and LlamaRMSNorm both appear to be equivalent to T5LayerNorm.
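To make the weight-conversion side a bit more concrete, here is a very rough, untested skeleton of what pretrained/weight_conversions/olmo2.py could look like, loosely modelled on the existing Llama converter. The state-dict keys are my best guess at TransformerLens conventions, and I have deliberately left the handling of q_norm/k_norm and of OLMo-2's post-attention/post-feedforward norm placement as TODOs, since those do not map one-to-one onto the Llama layout:

import einops


def convert_olmo2_weights(olmo2, cfg):
    """Rough, untested sketch: map Olmo2ForCausalLM weights onto (assumed)
    TransformerLens parameter names, following the Llama converter's pattern."""
    state_dict = {}
    state_dict["embed.W_E"] = olmo2.model.embed_tokens.weight

    for l in range(cfg.n_layers):
        layer = olmo2.model.layers[l]

        # Attention: reshape (n_heads * d_head, d_model) -> (n_heads, d_model, d_head)
        state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
            layer.self_attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(
            layer.self_attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(
            layer.self_attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            layer.self_attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads
        )

        # Gated MLP, as in the Llama converter.
        state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = layer.mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T

        # TODO: OLMo-2 has no input_layernorm; its post_attention_layernorm and
        # post_feedforward_layernorm sit after the sublayers, so they do not map
        # one-to-one onto ln1/ln2 in the Llama layout.
        # TODO: q_norm / k_norm have no obvious slot in HookedTransformer yet.

    state_dict["ln_final.w"] = olmo2.model.norm.weight
    state_dict["unembed.W_U"] = olmo2.lm_head.weight.T
    return state_dict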

I'm tempted to give a PR a shot, but I'm not sure I know enough about TransformerLens. Is there anyone who could help bridge the gap?

@Neelectric

Actually, it looks like #718 and https://github.com/jonasrohw/TransformerLens/tree/OLMo are closely related.

jonasrohw mentioned this issue Dec 15, 2024