Fix: Variable name #476

Open
wants to merge 1 commit into
base: main

dhdbsrlw

Hello.

I think the key name used in 'lit_llama/utils.py' should be changed.

When I ran your code, the existing lookup raised an error. I then debugged it by printing.

There was no key named 'transformer.wte.weight' in the checkpoint.

Thank you.
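The debugging step described above can be sketched roughly as follows. This is a minimal illustration, not the code the author actually ran: the helper name `matching_keys`, the stand-in checkpoint, and all shapes are hypothetical, and a real checkpoint would be a dict of tensors loaded from disk.

```python
from types import SimpleNamespace


def matching_keys(checkpoint: dict, fragments=("wte", "lm_head")) -> dict:
    """Return {key: shape} for entries whose name contains any of the fragments.

    Hypothetical helper for illustration; it only relies on each value
    exposing a `.shape` attribute, as tensors do.
    """
    return {
        name: tuple(value.shape)
        for name, value in checkpoint.items()
        if any(fragment in name for fragment in fragments)
    }


# Hypothetical stand-in for a loaded checkpoint (assumed names and shapes).
# It mimics the reported situation: 'lm_head.weight' is present,
# 'transformer.wte.weight' is not.
fake_checkpoint = {
    "lm_head.weight": SimpleNamespace(shape=(32000, 4096)),
    "some.other.weight": SimpleNamespace(shape=(4096, 4096)),
}

print(matching_keys(fake_checkpoint))
print("transformer.wte.weight" in fake_checkpoint)
```

Listing the candidate keys together with their shapes makes it easy to see which key is available and whether its second dimension is the value the lookup expects.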

@Borda Borda requested review from rasbt and removed request for carmocca August 22, 2024 14:45
@@ -28,7 +28,7 @@ def llama_model_lookup(checkpoint: dict) -> str:

     Checks the width of the lm_head.weight matrix, as these uniquely identify the model.
     """
-    embedding_size = checkpoint['transformer.wte.weight'].shape[1]
+    embedding_size = checkpoint['lm_head.weight'].shape[1]

@rasbt rasbt Aug 22, 2024


I think the original key should exist for the original Llama weights. Maybe you downloaded the weights from somewhere else? Perhaps the following would be a good compromise:

Suggested change
-    embedding_size = checkpoint['lm_head.weight'].shape[1]
+    if 'transformer.wte.weight' in checkpoint:
+        embedding_size = checkpoint['transformer.wte.weight'].shape[1]
+    elif 'lm_head.weight' in checkpoint:
+        embedding_size = checkpoint['lm_head.weight'].shape[1]
+    else:
+        raise ValueError("Neither 'transformer.wte.weight' nor 'lm_head.weight' found in the checkpoint")


@rasbt rasbt left a comment


Can't reproduce it at the moment, but perhaps we can have a compromise here that remains backwards compatible.
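One way to check that such a compromise stays backwards compatible is to run the fallback against both checkpoint layouts. The sketch below is an assumption-laden illustration rather than the repository's code: `embedding_size_from` and `CANDIDATE_KEYS` are hypothetical names, the stand-in checkpoints and their shapes are made up, and it assumes (as both versions of the diff do) that the embedding width is the second dimension of either matrix.

```python
from types import SimpleNamespace

# Keys are tried in order; the original key comes first so that
# checkpoints that already worked keep resolving the same way.
CANDIDATE_KEYS = ("transformer.wte.weight", "lm_head.weight")


def embedding_size_from(checkpoint: dict) -> int:
    """Hypothetical helper: read the embedding width from the first key found."""
    for key in CANDIDATE_KEYS:
        if key in checkpoint:
            return checkpoint[key].shape[1]
    raise ValueError(
        "Neither 'transformer.wte.weight' nor 'lm_head.weight' found in the checkpoint"
    )


# Hypothetical stand-ins for the two layouts discussed in this thread;
# only the `.shape` attribute of each entry is used.
original_layout = {"transformer.wte.weight": SimpleNamespace(shape=(32000, 4096))}
reported_layout = {"lm_head.weight": SimpleNamespace(shape=(32000, 4096))}

print(embedding_size_from(original_layout))
print(embedding_size_from(reported_layout))
```

Keeping the key order in a single tuple means a further layout could be supported by adding one entry, while a checkpoint with none of the keys still fails loudly with a `ValueError`.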
