Skip to content

Commit

Permalink
Infer whether a model needs to be exported to ONNX or not (#2181)
Browse files Browse the repository at this point in the history
* add files matching patterns

* rename

* Infer if model needs to be exported to ONNX

* adapt test

* add test for export diffusers model

* adapt test

* fix for local files

* set subfolder for local dir

* force export in tests

* check for all available models

* add tests

* fix model files loading when detected

* fix warning message

* refactor test

* add warning when filename ignored

* fix style

* fix for windows
  • Loading branch information
echarlaix authored Feb 12, 2025
1 parent 512d5c6 commit 6335599
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 219 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_onnxruntime.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ jobs:
run: |
pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -n auto
env:
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
1 change: 1 addition & 0 deletions optimum/onnxruntime/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!(with_past|merged)).)*?\.onnx"
DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx"
DECODER_MERGED_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?merged(.*)?\.onnx"
ONNX_FILE_PATTERN = r".*\.onnx$"
113 changes: 60 additions & 53 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""Classes handling causal-lm related architectures in ONNX Runtime."""

import logging
import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
Expand All @@ -32,12 +34,17 @@
from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
from ..onnx.utils import check_model_uses_external_data
from ..utils import NormalizedConfigManager, is_transformers_version
from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
from ..utils.file_utils import find_files_matching_pattern
from ..utils.save_utils import maybe_save_preprocessors
from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
from .constants import (
DECODER_MERGED_ONNX_FILE_PATTERN,
DECODER_ONNX_FILE_PATTERN,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
ONNX_FILE_PATTERN,
)
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME
from .utils import ONNX_WEIGHTS_NAME


if TYPE_CHECKING:
Expand Down Expand Up @@ -400,7 +407,6 @@ def _from_pretrained(
**kwargs,
) -> "ORTModelForCausalLM":
generation_config = kwargs.pop("generation_config", None)
model_path = Path(model_id)

# We do not implement the logic for use_cache=False, use_merged=True
if use_cache is False:
Expand All @@ -411,68 +417,69 @@ def _from_pretrained(
)
use_merged = False

decoder_name = "decoder_file_name" if use_cache else "decoder_with_past_file_name"
decoder_file_name = kwargs.pop(decoder_name, None)
onnx_files = find_files_matching_pattern(
model_id,
ONNX_FILE_PATTERN,
glob_pattern="**/*.onnx",
subfolder=subfolder,
token=token,
revision=revision,
)

if len(onnx_files) == 0:
raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")

if decoder_file_name is not None:
logger.warning(f"The `{decoder_name}` argument is deprecated, please use `file_name` instead.")
file_name = file_name or decoder_file_name
if len(onnx_files) == 1:
subfolder = onnx_files[0].parent
_file_name = onnx_files[0].name
if file_name and file_name != _file_name:
raise FileNotFoundError(f"Trying to load {file_name} but only found {_file_name}")
file_name = _file_name

if file_name is None:
decoder_path = None
# We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
# and use_merged = True (explicitely specified by the user)
# and use_merged = True (explicitly specified by the user)
model_files = []
# Check first for merged models and then for decoder / decoder_with_past models
if use_merged is not False:
try:
decoder_path = ORTModelForCausalLM.infer_onnx_filename(
model_id,
[DECODER_MERGED_ONNX_FILE_PATTERN],
argument_name=None,
subfolder=subfolder,
token=token,
revision=revision,
)
use_merged = True
file_name = decoder_path.name
except FileNotFoundError as e:
if use_merged is True:
raise FileNotFoundError(
"The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
" but no ONNX file for a merged decoder could be found in"
f" {str(Path(model_id, subfolder))}, with the error: {e}"
)
use_merged = False
model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
use_merged = len(model_files) != 0

if use_merged is False:
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
# exclude decoder file for first iteration
decoder_path = ORTModelForCausalLM.infer_onnx_filename(
model_id,
[r"^((?!decoder).)*.onnx", pattern],
argument_name=None,
subfolder=subfolder,
token=token,
revision=revision,
)
file_name = decoder_path.name
model_files = [p for p in onnx_files if re.search(pattern, str(p))]

if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST:
raise ValueError(
f"ONNX Runtime inference using {ONNX_DECODER_WITH_PAST_NAME} has been deprecated for {config.model_type} architecture. Please re-export your model with optimum>=1.14.0 or set use_cache=False. For details about the deprecation, please refer to https://github.com/huggingface/optimum/releases/tag/v1.14.0."
# if file_name is specified we don't filter legacy models
if not model_files or file_name:
model_files = onnx_files
else:
logger.warning(
f"Legacy models found in {model_files} will be loaded. "
"Legacy models will be deprecated in the next version of optimum, please re-export your model"
)
_file_name = model_files[0].name
subfolder = model_files[0].parent

regular_file_names = []
for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]:
regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name)
defaut_file_name = file_name or "model.onnx"
for file in model_files:
if file.name == defaut_file_name:
_file_name = file.name
subfolder = file.parent
break

if file_name not in regular_file_names:
file_name = _file_name

if len(model_files) > 1:
logger.warning(
f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime that are {regular_file_names}, the "
f"{cls.__name__} might not behave as expected."
f"Too many ONNX model files were found in {' ,'.join(map(str, model_files))}. "
"specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
f"Loading the file {file_name} in the subfolder {subfolder}."
)

if os.path.isdir(model_id):
model_id = subfolder
subfolder = ""

model_cache_path, preprocessors = cls._cached_file(
model_path=model_path,
model_path=model_id,
token=token,
revision=revision,
force_download=force_download,
Expand All @@ -481,7 +488,7 @@ def _from_pretrained(
subfolder=subfolder,
local_files_only=local_files_only,
)
new_model_save_dir = model_cache_path.parent
new_model_save_dir = Path(model_cache_path).parent

# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
# instead of the path only.
Expand Down
Loading

0 comments on commit 6335599

Please sign in to comment.