From addad9264ae5dab1e37adf94d66ce97a8841bee9 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 19 Sep 2023 15:26:08 +0200 Subject: [PATCH] fix external data --- optimum/onnxruntime/modeling_decoder.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 66190069e5d..222fd9a852a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -49,7 +49,7 @@ parse_device, validate_provider_availability, ) - +from ..onnx.utils import check_model_uses_external_data if TYPE_CHECKING: from transformers import PretrainedConfig @@ -1033,9 +1033,14 @@ def _from_pretrained( ################################################################################################## - # Since v1.7.0 decoder with past models have fixed sequence length of 1 + # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic - onnx_model = onnx.load(model_cache_path) + onnx_model = onnx.load(str(model_cache_path), load_external_data=False) + model_uses_external_data = check_model_uses_external_data(onnx_model) + + if model_uses_external_data: + onnx_model = onnx.load(str(model_cache_path), load_external_data=True) + input_dims = { node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input @@ -1047,9 +1052,17 @@ def _from_pretrained( for node in onnx_model.graph.output } output_dims["logits"][1] = "sequence_length" - static_model = onnx.load(model_cache_path) - updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims) - onnx.save(updated_model, model_cache_path) + onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) + + onnx.save( + onnx_model, + str(model_cache_path), + save_as_external_data=model_uses_external_data, + all_tensors_to_one_file=True, + location=model_cache_path.name + "_data", + size_threshold=0, + ) + del onnx_model model = ORTModel.load_model( model_cache_path,