Skip to content

Commit

Permalink
raise when unsupported model
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 25, 2023
1 parent ed8e74f commit c13a170
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,20 +962,33 @@ def _from_pretrained(
use_merged = False

if use_merged is False:
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
# exclude decoder file for first iteration
decoder_path = ORTModelForCausalLM.infer_onnx_filename(
model_id,
[
r"^((?!decoder).)*.onnx",
DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN,
],
"file_name",
[r"^((?!decoder).)*.onnx", pattern],
argument_name=None,
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)
file_name = decoder_path.name

MODEL_TO_PATCH_FOR_PAST = {
"bloom",
"mpt",
"llama",
"blenderbot-small",
"blenderbot",
"opt",
"pegasus",
"bart",
}
if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST:
raise ValueError(
f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False."
)

regular_file_names = []
for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]:
regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name)
Expand Down

0 comments on commit c13a170

Please sign in to comment.