
[Help][BUG] KeyError: 'lm_head.weight' on loading llama 3.2 #1920

Open
steveepreston opened this issue Oct 13, 2024 · 5 comments
Labels
Gemma (Gemma model specific issues)

Comments

@steveepreston

steveepreston commented Oct 13, 2024

Trying to load Llama 3.2 on a TPU VM v3-8 via this:

import keras
import keras_nlp

device_mesh = keras.distribution.DeviceMesh((1, 8), ["batch", "model"], devices=keras.distribution.list_devices())
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ("model", None)
layout_map["decoder_block.*attention.*(query|key|value)/kernel"] = ("model", None, None)
layout_map["decoder_block.*attention_output/kernel"] = ("model", None, None)
layout_map["decoder_block.*ffw_gating.*/kernel"] = (None, "model")
layout_map["decoder_block.*ffw_linear/kernel"] = ("model", None)
model_parallel = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)

model = keras_nlp.models.Llama3CausalLM.from_preset("meta-llama/Llama-3.2-3B-Instruct")

but it throws this error:

KeyError: 'lm_head.weight'

Note: I took the layout_map code from This Example. I don't know whether the problem is with the layout_map or with Llama3CausalLM.

@github-actions github-actions bot added the Gemma (Gemma model specific issues) label Oct 13, 2024
@steveepreston steveepreston changed the title from "[BUG] KeyError: 'lm_head.weight' on loading llama 3.2" to "[Help][BUG] KeyError: 'lm_head.weight' on loading llama 3.2" Oct 13, 2024
@Gopi-Uppari

Hi @steveepreston,

I was able to execute the code using the Gemma model, and it worked without any issues. For the Llama model, however, could you please reach out to the Llama team for further assistance? Please refer to the Gist file for more details.

Thank you.

@steveepreston
Author

Thank you for your attention @Gopi-Uppari

Yes, Gemma executed successfully in my test too (although gemma-2-9b-it threw an OOM on the TPU).
The problem is with the Llama model.

OK, I will try to create another issue there as well.

@Gopi-Uppari

Could you please confirm whether this issue is resolved for you by the above comment? If it is resolved, please feel free to close the issue.

Thank you.

@steveepreston
Author

The problem is not resolved, and I've moved to PyTorch.
Maybe I'll come back to follow up and solve this in the future. There is still no example of Llama3CausalLM + XLA on the web.

@SamanehSaadat
Member

Hi @steveepreston

Variable paths in Llama are different from those in Gemma, so the layout map that works for Gemma doesn't work for Llama (see here).
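For reference, a quick way to see the actual variable paths (a minimal sketch; the preset name is only an illustrative assumption) is to instantiate the backbone without any distribution set and print each variable's path:

import keras_nlp

# Inspect the backbone's variable paths so the layout-map regexes can be
# written against the real names (preset name here is just an example).
backbone = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en")
for variable in backbone.weights:
    print(variable.path, variable.shape)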

Recently, get_layout_map was added for Llama here. So instead of specifying the layout map manually, you can use this function: layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh).

We haven't added get_layout_map for Llama3 yet, but if it has the same architecture as Llama, you can copy the layout map from here.
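For anyone hitting this later, here is a minimal sketch of the suggested approach, reusing the mesh from the original report together with LlamaBackbone.get_layout_map (applying that map to Llama3CausalLM is an assumption based on the architecture note above):

import keras
import keras_nlp

# Same 1 x 8 mesh as the original report: 8-way model parallelism on a TPU v3-8.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8), ["batch", "model"], devices=keras.distribution.list_devices()
)

# Build a layout map that matches Llama's variable paths instead of Gemma's.
layout_map = keras_nlp.models.LlamaBackbone.get_layout_map(device_mesh)

keras.distribution.set_distribution(
    keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
)

model = keras_nlp.models.Llama3CausalLM.from_preset("meta-llama/Llama-3.2-3B-Instruct")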
