Skip to content

Commit

Permalink
Correct get tensor name for stateful key, values (huggingface#874)
Browse files Browse the repository at this point in the history
  • Loading branch information
praasz authored Aug 22, 2024
1 parent 32d193d commit c177040
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimum/exporters/openvino/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
"""

key_value_input_names = [
key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.get_names())
key_name for key in ov_model.inputs for key_name in key.get_names() if "key_values" in key_name
]
key_value_output_names = [
key.get_any_name() for key in ov_model.outputs if any("present" in key_name for key_name in key.get_names())
key_name for key in ov_model.outputs for key_name in key.get_names() if "present" in key_name
]
not_kv_inputs = [
input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
Expand Down

0 comments on commit c177040

Please sign in to comment.