
Fix llama ONNX export #1432

Merged (2 commits) Oct 6, 2023
Conversation

@fxmarty (Contributor) commented Oct 5, 2023

Llama uses an optional configuration key, num_key_value_heads, for the number of key/value heads, while num_attention_heads is used to compute the head dimension. This was unfortunately not handled in #975 (apart from Llama 2 70B, the Llama and Llama 2 series do not set num_key_value_heads), likely because the key did not exist at the time.

Fixes #1399
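The shape mismatch described above can be illustrated with a minimal sketch (not Optimum's actual code; the function name and config class are hypothetical): with grouped-query attention, the past key/value cache uses num_key_value_heads, while the head dimension is still derived from num_attention_heads.

```python
def past_kv_shape(batch_size, sequence_length, config):
    # Fall back to num_attention_heads when the optional key is absent,
    # as in the original Llama checkpoints (no GQA).
    num_kv_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads
    # The head dimension is always derived from num_attention_heads.
    head_dim = config.hidden_size // config.num_attention_heads
    return (batch_size, num_kv_heads, sequence_length, head_dim)


class Cfg:
    # Values matching Llama 2 70B, which uses grouped-query attention.
    hidden_size = 8192
    num_attention_heads = 64
    num_key_value_heads = 8

print(past_kv_shape(2, 16, Cfg))  # (2, 8, 16, 128)
```

Using num_attention_heads for the first dimension here (as the export did before this fix) would produce (2, 64, 16, 128) and fail for GQA checkpoints such as Llama 2 70B or TinyLlama.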

@fxmarty fxmarty requested review from regisss and mht-sharma October 5, 2023 15:11
@fxmarty fxmarty merged commit ba113e5 into huggingface:main Oct 6, 2023
40 of 52 checks passed
@echarlaix (Collaborator) left a comment


thanks for the fix

@@ -216,7 +216,48 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
Collaborator:

It would make sense to move it to optimum/utils/input_generators.py, next to

class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):

random_sequence_length_range=random_sequence_length_range,
**kwargs,
)
self.num_key_value_heads = normalized_config.num_key_value_heads
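A framework-agnostic sketch of what such a dummy past-key-values generator produces (class name and signature are simplified illustrations, not Optimum's real API): one (key, value) pair of random tensors per decoder layer, shaped with num_key_value_heads rather than num_attention_heads.

```python
import numpy as np

class DummyPastKeyValues:
    """Illustrative stand-in for a DummyPastKeyValuesGenerator subclass."""

    def __init__(self, batch_size, sequence_length, num_layers,
                 num_key_value_heads, head_dim):
        # Past KV cache shape for grouped-query attention models.
        self.shape = (batch_size, num_key_value_heads, sequence_length, head_dim)
        self.num_layers = num_layers

    def generate(self):
        # One (key, value) pair per decoder layer.
        return [
            (np.random.rand(*self.shape).astype(np.float32),
             np.random.rand(*self.shape).astype(np.float32))
            for _ in range(self.num_layers)
        ]

pkv = DummyPastKeyValues(batch_size=1, sequence_length=4, num_layers=2,
                         num_key_value_heads=8, head_dim=128).generate()
print(len(pkv), pkv[0][0].shape)  # 2 (1, 8, 4, 128)
```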
Collaborator:

Should we also add a fix in prepare_inputs_for_merged (as done in #1425)?

@fxmarty (Contributor, author):

Yes, thank you for taking care of it!

echarlaix added a commit that referenced this pull request Oct 6, 2023
echarlaix added a commit that referenced this pull request Oct 9, 2023
* Add ONNX export Mistral models support

* add test

* format

* fix format

* fix key _config

* tmp install transformers from source for tests

* change model id

* fix after #1432 merged

* fix

* format

* fix

Successfully merging this pull request may close these issues.

Export in ONNX/FP16 of PY007/TinyLlama-1.1B-Chat-v0.2 fails
2 participants