Skip to content

Commit

Permalink
Avoid loading onnx file in weight deduplication if not necessary (#1648)
Browse files Browse the repository at this point in the history
avoid loading onnx if not necessary
  • Loading branch information
fxmarty authored Jan 17, 2024
1 parent 130197f commit 3679079
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,13 +539,14 @@ def post_process_exported_models(
 if is_accelerate_available():
     logger.info("Deduplicating shared (tied) weights...")
     for subpath, key in zip(onnx_files_subpaths, models_and_onnx_configs):
-        onnx_model = onnx.load(os.path.join(path, subpath))
-
         torch_model = models_and_onnx_configs[key][0]
         tied_params = find_tied_parameters(torch_model)
-        remove_duplicate_weights_from_tied_info(
-            onnx_model, torch_model, tied_params, save_path=os.path.join(path, subpath)
-        )
+
+        if len(tied_params) > 0:
+            onnx_model = onnx.load(os.path.join(path, subpath))
+            remove_duplicate_weights_from_tied_info(
+                onnx_model, torch_model, tied_params, save_path=os.path.join(path, subpath)
+            )
 else:
     logger.warning(
         "Weight deduplication check in the ONNX export requires accelerate. Please install accelerate to run it."
Expand Down

0 comments on commit 3679079

Please sign in to comment.