Fix llama ONNX export #1432
Conversation
thanks for the fix
@@ -216,7 +216,48 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
Would make sense to move it to optimum/utils/input_generators.py, next to the existing generators:

optimum/optimum/utils/input_generators.py (line 830 in 099cd73):
class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
            random_sequence_length_range=random_sequence_length_range,
            **kwargs,
        )
        self.num_key_value_heads = normalized_config.num_key_value_heads
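For context, here is a minimal, self-contained sketch of what such a dummy past-key-values generator produces for Llama-style models. The function name and simplified interface are illustrative only, not optimum's actual API; the point is that the number of KV heads and the head dimension come from different config keys:

```python
import torch


def dummy_past_key_values(batch_size, past_seq_len, num_layers,
                          num_attention_heads, num_key_value_heads, hidden_size):
    """Build random past key/value tensors shaped the way Llama expects them."""
    # The head dimension is derived from num_attention_heads ...
    head_dim = hidden_size // num_attention_heads
    # ... while the number of KV heads comes from the optional num_key_value_heads key.
    shape = (batch_size, num_key_value_heads, past_seq_len, head_dim)
    return [
        (torch.rand(shape), torch.rand(shape))  # one (key, value) pair per decoder layer
        for _ in range(num_layers)
    ]


# Llama 2 70B-style dimensions: 64 attention heads but only 8 KV heads (GQA).
past = dummy_past_key_values(batch_size=2, past_seq_len=16, num_layers=80,
                             num_attention_heads=64, num_key_value_heads=8,
                             hidden_size=8192)
print(past[0][0].shape)  # torch.Size([2, 8, 16, 128])
```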
Should we also add a fix in prepare_inputs_for_merged (as done in #1425)?
yes thank you for taking care of it!
* Add ONNX export Mistral models support
* add test
* format
* fix format
* fix key _config
* tmp install transformers from source for tests
* change model id
* fix after #1432 merged
* fix
* format
* fix
Llama uses an optional configuration key num_key_value_heads for the number of key/value heads, and uses num_attention_heads to compute the head dimension. This was unfortunately not implemented in #975 (apart from Llama 2 70B, the Llama and Llama 2 series do not make use of this num_key_value_heads key), as the key probably did not exist at the time.

Fixes #1399
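To make the key handling described above concrete, here is a small sketch using values from the public Llama 2 70B config; this is an illustration of the fallback logic, not the code in this PR:

```python
from transformers import LlamaConfig

# Llama 2 70B-style configuration; most Llama checkpoints simply omit
# num_key_value_heads, in which case it should default to num_attention_heads.
config = LlamaConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)

# Fall back to num_attention_heads when the optional key is absent (plain MHA).
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

# The head dimension is always computed from num_attention_heads,
# not from num_key_value_heads, even for GQA models such as Llama 2 70B.
head_dim = config.hidden_size // config.num_attention_heads

# Each past key/value tensor of the exported decoder then has shape
# (batch_size, num_kv_heads, past_sequence_length, head_dim) = (b, 8, s, 128).
print(num_kv_heads, head_dim)  # 8 128
```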