Skip to content

Commit

Permalink
fix external data
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 19, 2023
1 parent e5fd9f8 commit addad92
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
parse_device,
validate_provider_availability,
)

from ..onnx.utils import check_model_uses_external_data

if TYPE_CHECKING:
from transformers import PretrainedConfig
Expand Down Expand Up @@ -1033,9 +1033,14 @@ def _from_pretrained(

##################################################################################################

# Since v1.7.0 decoder with past models have fixed sequence length of 1
# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic
onnx_model = onnx.load(model_cache_path)
onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

if model_uses_external_data:
onnx_model = onnx.load(str(model_cache_path), load_external_data=True)

input_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.input
Expand All @@ -1047,9 +1052,17 @@ def _from_pretrained(
for node in onnx_model.graph.output
}
output_dims["logits"][1] = "sequence_length"
static_model = onnx.load(model_cache_path)
updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims)
onnx.save(updated_model, model_cache_path)
onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)

onnx.save(
onnx_model,
str(model_cache_path),
save_as_external_data=model_uses_external_data,
all_tensors_to_one_file=True,
location=model_cache_path.name + "_data",
size_threshold=0,
)
del onnx_model

model = ORTModel.load_model(
model_cache_path,
Expand Down

0 comments on commit addad92

Please sign in to comment.