Add ONNX and ORT support for Falcon #1391
Conversation
Fixes #1172

# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
nit: I would move this comment inside the method.
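For context on the diff comment above: passing output_attentions=True forces the model down its eager attention path instead of the fused scaled_dot_product_attention kernel, which the ONNX exporter did not support at the time. A minimal sketch of this dual-path dispatch pattern, using a hypothetical toy module (not the actual Falcon code):

```python
import torch


class DummyAttention(torch.nn.Module):
    """Toy module mimicking the dual-path attention dispatch discussed above."""

    def forward(self, q, k, v, output_attentions=False):
        if output_attentions:
            # Eager path: explicit matmul + softmax, exportable to ONNX.
            scores = torch.softmax(
                q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1
            )
            return scores @ v
        # Fused path: was not supported by the ONNX exporter at the time.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


q = k = v = torch.randn(1, 4, 8, 16)
out_eager = DummyAttention()(q, k, v, output_attentions=True)
out_fused = DummyAttention()(q, k, v, output_attentions=False)
```

Both paths compute the same attention output; only the eager one traces into plain MatMul/Softmax ONNX ops.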
generation_config=generation_config,
**kwargs,
)
# self.num_kv_heads = config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1
To remove?
Let's keep it for now
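The commented-out line under discussion encodes how Falcon's effective number of key/value heads depends on two config flags. A sketch of that logic in isolation (FalconLikeConfig is a hypothetical stand-in for the real FalconConfig):

```python
from dataclasses import dataclass


@dataclass
class FalconLikeConfig:
    # Hypothetical stand-in for transformers' FalconConfig.
    num_kv_heads: int = 8
    new_decoder_architecture: bool = False
    multi_query: bool = True


def effective_num_kv_heads(config: FalconLikeConfig) -> int:
    # New decoder architecture (falcon-40b style): grouped-query attention,
    # so the configured num_kv_heads applies. Multi-query without the new
    # architecture (falcon-7b style): a single shared key/value head.
    if config.new_decoder_architecture or not config.multi_query:
        return config.num_kv_heads
    return 1


print(effective_num_kv_heads(FalconLikeConfig(new_decoder_architecture=True)))  # 8
print(effective_num_kv_heads(FalconLikeConfig()))  # 1
```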
@@ -211,7 +211,7 @@ class NormalizedConfigManager:
    "blenderbot": BartLikeNormalizedTextConfig,
    "blenderbot_small": BartLikeNormalizedTextConfig,
    "bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
-   "falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
+   "falcon": NormalizedTextConfig,
Question: does NormalizedConfig have a NUM_KV_HEADS attribute to normalize it or not?
No
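For readers unfamiliar with the with_args call in the diff above: it builds a config wrapper whose normalized attribute names are remapped to model-specific keys. A simplified sketch of the idea (NormalizedTextConfigSketch is illustrative only, not optimum's actual implementation):

```python
class NormalizedTextConfigSketch:
    # Default mapping: normalized attribute name -> model config key.
    NUM_LAYERS = "num_hidden_layers"
    NUM_ATTENTION_HEADS = "num_attention_heads"

    def __init__(self, config):
        self._config = config

    @classmethod
    def with_args(cls, **overrides):
        # Build a subclass with remapped keys, e.g. num_layers="n_layer"
        # for Bloom-style configs.
        attrs = {name.upper(): key for name, key in overrides.items()}
        return type("CustomNormalizedConfig", (cls,), attrs)

    def __getattr__(self, name):
        # Resolve e.g. num_layers -> NUM_LAYERS -> "n_layer" -> value.
        key = getattr(type(self), name.upper())
        return self._config[key]


BloomNormalized = NormalizedTextConfigSketch.with_args(num_layers="n_layer")
cfg = BloomNormalized({"n_layer": 30, "num_attention_heads": 32})
print(cfg.num_layers)  # 30
```

The question above asks whether such a normalized NUM_KV_HEADS mapping exists; per the answer, it does not, which is why the falcon entry falls back to the plain NormalizedTextConfig.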
This one was more painful than it should have been because:
- Using the position_ids input together with generate is bugged: the position_ids generated in the model (https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/models/falcon/modeling_falcon.py#L932) have a different shape than the position_ids generated in generate. I believe this is a bug in many Transformers models.
- normalized_config: this should be refactored with inheritance to avoid any control flow at all.
- The Trilu op is not implemented in ONNX Runtime for this case: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Trilu(14) node with name '/decoder/Trilu' (microsoft/onnxruntime#16189)

Remaining issue: I think the repeat_interleave ONNX export inserts a Loop node in the ONNX graph, which we may want to avoid. EDIT: fixed in PyTorch 2.1.
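On the repeat_interleave point: the call that expands key/value heads for grouped-query attention can be rewritten with expand + reshape, which traces to static Expand/Reshape ONNX ops instead of a Loop. A hypothetical sketch of that equivalent rewrite (not the code merged in this PR):

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of x.repeat_interleave(n_rep, dim=1) via expand + reshape.

    x has shape (batch, num_kv_heads, seq_len, head_dim).
    """
    batch, num_kv_heads, seq_len, head_dim = x.shape
    # Insert a repeat axis after the head axis, broadcast it, then merge it
    # back into the head axis: [h0, h0, h0, h1, h1, h1, ...].
    x = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


x = torch.randn(2, 4, 5, 8)
expanded = repeat_kv(x, 3)
```

Both forms produce identical tensors; the difference is only in what the ONNX exporter emits for them on pre-2.1 PyTorch.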