Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Oct 5, 2023
1 parent 1a972b3 commit 899db77
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,32 +125,32 @@ def post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder-only was exported separately without/with past
if self.use_past is True and len(models_and_onnx_configs) == 2:
decoder_path = Path(path, onnx_files_subpaths[0])
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
merge_decoders(
decoder=decoder_path,
decoder_with_past=decoder_with_past_path,
save_path=decoder_merged_path,
)
except Exception as e:
raise Exception(f"Unable to merge decoders. Detailed error: {e}")

# In order to do the validation of the two branches on the same file
onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]

# We validate the two branches of the decoder model then
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False

# Past key values won't be generated by default, but added in the input
models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True
# # Attempt to merge only if the decoder-only was exported separately without/with past
# if self.use_past is True and len(models_and_onnx_configs) == 2:
# decoder_path = Path(path, onnx_files_subpaths[0])
# decoder_with_past_path = Path(path, onnx_files_subpaths[1])
# decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
# try:
# merge_decoders(
# decoder=decoder_path,
# decoder_with_past=decoder_with_past_path,
# save_path=decoder_merged_path,
# )
# except Exception as e:
# raise Exception(f"Unable to merge decoders. Detailed error: {e}")

# # In order to do the validation of the two branches on the same file
# onnx_files_subpaths = [decoder_merged_path.name, decoder_merged_path.name]

# # We validate the two branches of the decoder model then
# models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
# models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False

# # Past key values won't be generated by default, but added in the input
# models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True

# models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
# models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True

return models_and_onnx_configs, onnx_files_subpaths

Expand Down

0 comments on commit 899db77

Please sign in to comment.