From 6335599da3bad65609620d1f64512e8241fd5aaa Mon Sep 17 00:00:00 2001
From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Date: Wed, 12 Feb 2025 18:51:55 +0100
Subject: [PATCH] Infer whether a model needs to be exported to ONNX or not
 (#2181)

* add files matching patterns
* rename
* Infer if model needs to be exported to ONNX
* adapt test
* add test for export diffusers model
* adapt test
* fix for local files
* set subfolder for local dir
* force export in tests
* check for all available models
* add tests
* fix model files loading when detected
* fix warning message
* refactor test
* add warning when filename ignored
* fix style
* fix for windows
---
 .github/workflows/test_onnxruntime.yml  |   2 +-
 optimum/onnxruntime/constants.py        |   1 +
 optimum/onnxruntime/modeling_decoder.py | 113 +++++++-------
 optimum/onnxruntime/modeling_ort.py     | 194 ++++++++++++------
 optimum/pipelines/pipelines_base.py     |  22 +--
 optimum/utils/file_utils.py             |  22 +--
 tests/onnxruntime/test_diffusion.py     |  21 +--
 tests/onnxruntime/test_modeling.py      |  71 ++++++---
 tests/onnxruntime/test_timm.py          |   2 +-
 9 files changed, 229 insertions(+), 219 deletions(-)

diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml
index bf7f15e263..64eced83e8 100644
--- a/.github/workflows/test_onnxruntime.yml
+++ b/.github/workflows/test_onnxruntime.yml
@@ -64,4 +64,4 @@ jobs:
         run: |
           pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -n auto
         env:
-          HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
+          HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
\ No newline at end of file
diff --git a/optimum/onnxruntime/constants.py b/optimum/onnxruntime/constants.py
index edf755e876..4325dcd464 100644
--- a/optimum/onnxruntime/constants.py
+++ b/optimum/onnxruntime/constants.py
@@ -16,3 +16,4 @@
 DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!(with_past|merged)).)*?\.onnx"
 DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx"
 DECODER_MERGED_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?merged(.*)?\.onnx"
+ONNX_FILE_PATTERN = r".*\.onnx$"
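Note that the new ONNX_FILE_PATTERN is anchored with `$`, so companion files such as external-data payloads (model.onnx_data) are not mistaken for model files. A quick illustrative check of the patterns (not part of the patch):

    import re

    ONNX_FILE_PATTERN = r".*\.onnx$"
    DECODER_MERGED_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?merged(.*)?\.onnx"

    for name in ("model.onnx", "decoder_model_merged.onnx", "model.onnx_data"):
        print(
            name,
            bool(re.search(ONNX_FILE_PATTERN, name)),                # an ONNX model file?
            bool(re.search(DECODER_MERGED_ONNX_FILE_PATTERN, name)),  # a merged decoder?
        )
    # model.onnx True False
    # decoder_model_merged.onnx True True
    # model.onnx_data False False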
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 9afa1bf19a..0099095eb1 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -14,6 +14,8 @@
 """Classes handling causal-lm related architectures in ONNX Runtime."""

 import logging
+import os
+import re
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -32,12 +34,17 @@
 from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
 from ..onnx.utils import check_model_uses_external_data
 from ..utils import NormalizedConfigManager, is_transformers_version
-from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
+from ..utils.file_utils import find_files_matching_pattern
 from ..utils.save_utils import maybe_save_preprocessors
-from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN
+from .constants import (
+    DECODER_MERGED_ONNX_FILE_PATTERN,
+    DECODER_ONNX_FILE_PATTERN,
+    DECODER_WITH_PAST_ONNX_FILE_PATTERN,
+    ONNX_FILE_PATTERN,
+)
 from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
 from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache
-from .utils import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME
+from .utils import ONNX_WEIGHTS_NAME

 if TYPE_CHECKING:
@@ -400,7 +407,6 @@ def _from_pretrained(
         **kwargs,
     ) -> "ORTModelForCausalLM":
         generation_config = kwargs.pop("generation_config", None)
-        model_path = Path(model_id)

         # We do not implement the logic for use_cache=False, use_merged=True
         if use_cache is False:
@@ -411,68 +417,69 @@ def _from_pretrained(
             )
             use_merged = False

-        decoder_name = "decoder_file_name" if use_cache else "decoder_with_past_file_name"
-        decoder_file_name = kwargs.pop(decoder_name, None)
+        onnx_files = find_files_matching_pattern(
+            model_id,
+            ONNX_FILE_PATTERN,
+            glob_pattern="**/*.onnx",
+            subfolder=subfolder,
+            token=token,
+            revision=revision,
+        )
+
+        if len(onnx_files) == 0:
+            raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")

-        if decoder_file_name is not None:
-            logger.warning(f"The `{decoder_name}` argument is deprecated, please use `file_name` instead.")
-            file_name = file_name or decoder_file_name
+        if len(onnx_files) == 1:
+            subfolder = onnx_files[0].parent
+            _file_name = onnx_files[0].name
+            if file_name and file_name != _file_name:
+                raise FileNotFoundError(f"Trying to load {file_name} but only found {_file_name}")
+            file_name = _file_name

-        if file_name is None:
-            decoder_path = None
-            # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it),
-            # and use_merged = True (explicitely specified by the user)
+        else:
+            model_files = []
+            # Check first for merged models and then for decoder / decoder_with_past models
             if use_merged is not False:
-                try:
-                    decoder_path = ORTModelForCausalLM.infer_onnx_filename(
-                        model_id,
-                        [DECODER_MERGED_ONNX_FILE_PATTERN],
-                        argument_name=None,
-                        subfolder=subfolder,
-                        token=token,
-                        revision=revision,
-                    )
-                    use_merged = True
-                    file_name = decoder_path.name
-                except FileNotFoundError as e:
-                    if use_merged is True:
-                        raise FileNotFoundError(
-                            "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()"
-                            " but no ONNX file for a merged decoder could be found in"
-                            f" {str(Path(model_id, subfolder))}, with the error: {e}"
-                        )
-                    use_merged = False
+                model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
+                use_merged = len(model_files) != 0

             if use_merged is False:
                 pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
-                # exclude decoder file for first iteration
-                decoder_path = ORTModelForCausalLM.infer_onnx_filename(
-                    model_id,
-                    [r"^((?!decoder).)*.onnx", pattern],
-                    argument_name=None,
-                    subfolder=subfolder,
-                    token=token,
-                    revision=revision,
-                )
-                file_name = decoder_path.name

-        if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST:
-            raise ValueError(
-                f"ONNX Runtime inference using {ONNX_DECODER_WITH_PAST_NAME} has been deprecated for {config.model_type} architecture. Please re-export your model with optimum>=1.14.0 or set use_cache=False. For details about the deprecation, please refer to https://github.com/huggingface/optimum/releases/tag/v1.14.0."
+                model_files = [p for p in onnx_files if re.search(pattern, str(p))]
+
+            # if file_name is specified we don't filter legacy models
+            if not model_files or file_name:
+                model_files = onnx_files
+            else:
+                logger.warning(
+                    f"Legacy models found in {model_files} will be loaded. "
" + "Legacy models will be deprecated in the next version of optimum, please re-export your model" ) + _file_name = model_files[0].name + subfolder = model_files[0].parent - regular_file_names = [] - for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]: - regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name) + defaut_file_name = file_name or "model.onnx" + for file in model_files: + if file.name == defaut_file_name: + _file_name = file.name + subfolder = file.parent + break - if file_name not in regular_file_names: + file_name = _file_name + + if len(model_files) > 1: logger.warning( - f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime that are {regular_file_names}, the " - f"{cls.__name__} might not behave as expected." + f"Too many ONNX model files were found in {' ,'.join(map(str, model_files))}. " + "specify which one to load by using the `file_name` and/or the `subfolder` arguments. " + f"Loading the file {file_name} in the subfolder {subfolder}." ) + if os.path.isdir(model_id): + model_id = subfolder + subfolder = "" + model_cache_path, preprocessors = cls._cached_file( - model_path=model_path, + model_path=model_id, token=token, revision=revision, force_download=force_download, @@ -481,7 +488,7 @@ def _from_pretrained( subfolder=subfolder, local_files_only=local_files_only, ) - new_model_save_dir = model_cache_path.parent + new_model_save_dir = Path(model_cache_path).parent # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it # instead of the path only. diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index e9633343c7..37c1af0db4 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -14,6 +14,7 @@ """ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers.""" import logging +import os import re import shutil import warnings @@ -58,6 +59,7 @@ TokenClassifierOutput, XVectorOutput, ) +from transformers.utils import is_offline_mode import onnxruntime as ort @@ -67,9 +69,9 @@ from ..onnx.utils import _get_external_data_paths from ..utils.file_utils import find_files_matching_pattern from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors +from .constants import ONNX_FILE_PATTERN from .io_binding import IOBindingHelper, TypeHelper from .utils import ( - ONNX_WEIGHTS_NAME, check_io_binding, get_device_for_provider, get_provider_for_device, @@ -430,20 +432,10 @@ def infer_onnx_filename( patterns: List[str], argument_name: str, subfolder: str = "", - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, fail_if_not_found: bool = True, ) -> str: - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index e9633343c7..37c1af0db4 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -14,6 +14,7 @@
 """ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers."""

 import logging
+import os
 import re
 import shutil
 import warnings
@@ -58,6 +59,7 @@
     TokenClassifierOutput,
     XVectorOutput,
 )
+from transformers.utils import is_offline_mode

 import onnxruntime as ort

@@ -67,9 +69,9 @@
 from ..onnx.utils import _get_external_data_paths
 from ..utils.file_utils import find_files_matching_pattern
 from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
+from .constants import ONNX_FILE_PATTERN
 from .io_binding import IOBindingHelper, TypeHelper
 from .utils import (
-    ONNX_WEIGHTS_NAME,
     check_io_binding,
     get_device_for_provider,
     get_provider_for_device,
@@ -430,20 +432,10 @@ def infer_onnx_filename(
         patterns: List[str],
         argument_name: str,
         subfolder: str = "",
-        use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         fail_if_not_found: bool = True,
     ) -> str:
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
         onnx_files = []
         for pattern in patterns:
             onnx_files = find_files_matching_pattern(
@@ -478,7 +470,6 @@ def _from_pretrained(
         cls,
         model_id: Union[str, Path],
         config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -493,47 +484,44 @@ def _from_pretrained(
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "ORTModel":
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
+        default_file_name = file_name or "model.onnx"

-        model_path = Path(model_id)
-        regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_WEIGHTS_NAME)
+        onnx_files = find_files_matching_pattern(
+            model_id,
+            ONNX_FILE_PATTERN,
+            glob_pattern="**/*.onnx",
+            subfolder=subfolder,
+            token=token,
+            revision=revision,
+        )

-        if file_name is None:
-            if model_path.is_dir():
-                onnx_files = list((model_path / subfolder).glob("*.onnx"))
-            else:
-                repo_files, _ = TasksManager.get_model_files(
-                    model_id, revision=revision, cache_dir=cache_dir, token=token
-                )
-                repo_files = map(Path, repo_files)
-                pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
-                onnx_files = [p for p in repo_files if p.match(pattern)]
+        if len(onnx_files) == 0:
+            raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")
+        if len(onnx_files) == 1 and file_name and file_name != onnx_files[0].name:
+            raise FileNotFoundError(f"Trying to load {file_name} but only found {onnx_files[0].name}")

-            if len(onnx_files) == 0:
-                raise FileNotFoundError(f"Could not find any ONNX model file in {model_path}")
-            elif len(onnx_files) > 1:
-                raise RuntimeError(
-                    f"Too many ONNX model files were found in {model_path}, specify which one to load by using the "
-                    "file_name argument."
-                )
-            else:
-                file_name = onnx_files[0].name
+        file_name = onnx_files[0].name
+        subfolder = onnx_files[0].parent
+
+        if len(onnx_files) > 1:
+            for file in onnx_files:
+                if file.name == default_file_name:
+                    file_name = file.name
+                    subfolder = file.parent
+                    break

-        if file_name not in regular_onnx_filenames:
             logger.warning(
-                f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime, the ORTModel might "
-                "not behave as expected."
+                f"Too many ONNX model files were found in {', '.join(map(str, onnx_files))}. "
+                "Specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
+                f"Loading the file {file_name} in the subfolder {subfolder}."
            )

+        if os.path.isdir(model_id):
+            model_id = subfolder
+            subfolder = ""
+
         model_cache_path, preprocessors = cls._cached_file(
-            model_path=model_path,
+            model_path=model_id,
             token=token,
             revision=revision,
             force_download=force_download,
@@ -542,7 +530,7 @@ def _from_pretrained(
             subfolder=subfolder,
             local_files_only=local_files_only,
         )
-        new_model_save_dir = model_cache_path.parent
+        new_model_save_dir = Path(model_cache_path).parent

         # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
         # instead of the path only.
@@ -569,7 +557,6 @@ def _from_transformers(
         cls,
         model_id: str,
         config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -585,15 +572,6 @@ def _from_transformers(
     ) -> "ORTModel":
         """The method will be deprecated in future releases."""

-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
         return cls._export(
             model_id=model_id,
             config=config,
@@ -616,7 +594,6 @@ def _export(
         cls,
         model_id: str,
         config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -630,15 +607,6 @@ def _export(
         use_io_binding: Optional[bool] = None,
         task: Optional[str] = None,
     ) -> "ORTModel":
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
         if task is None:
             task = cls._auto_model_to_task(cls.auto_model_class)

@@ -684,6 +652,7 @@ def from_pretrained(
         subfolder: str = "",
         config: Optional["PretrainedConfig"] = None,
         local_files_only: bool = False,
+        revision: Optional[str] = None,
         provider: str = "CPUExecutionProvider",
         session_options: Optional[ort.SessionOptions] = None,
         provider_options: Optional[Dict[str, Any]] = None,
@@ -731,15 +700,67 @@ def from_pretrained(
                 raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
             token = use_auth_token

+        if isinstance(model_id, Path):
+            model_id = model_id.as_posix()
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: setting `local_files_only=True`")
+            local_files_only = True
+
+        _export = export
+        try:
+            if local_files_only and not os.path.isdir(model_id):
+                object_id = model_id.replace("/", "--")
+                cached_model_dir = os.path.join(cache_dir, f"models--{object_id}")
+                refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main")
+                with open(refs_file) as f:
+                    _revision = f.read()
+                model_id = os.path.join(cached_model_dir, "snapshots", _revision)
+
+            onnx_files = find_files_matching_pattern(
+                model_id,
+                pattern=ONNX_FILE_PATTERN,
+                glob_pattern="**/*.onnx",
+                subfolder=subfolder,
+                token=token,
+                revision=revision,
+            )
+
+            _export = len(onnx_files) == 0
+            if _export ^ export:
+                if export:
+                    logger.warning(
+                        f"The model {model_id} was already converted to ONNX but got `export=True`, the model will be converted to ONNX once again. "
+                        "Don't forget to save the resulting model with `.save_pretrained()`"
+                    )
+                    _export = True
+                else:
+                    logger.warning(
+                        f"No ONNX files were found for {model_id}, setting `export=True` to convert the model to ONNX. "
+                        "Don't forget to save the resulting model with `.save_pretrained()`"
+                    )
+        except Exception as exception:
+            logger.warning(
+                f"Could not infer whether the model was already converted to ONNX or not, keeping `export={export}`.\n{exception}"
+            )
+
+        if _export:
+            file_name = kwargs.pop("file_name", None)
+            if file_name is not None:
+                logger.warning(
+                    f"`file_name` was set to `{file_name}` but will be ignored as the model will be converted to ONNX"
+                )
+
         return super().from_pretrained(
             model_id,
-            export=export,
+            export=_export,
             force_download=force_download,
             token=token,
             cache_dir=cache_dir,
             subfolder=subfolder,
             config=config,
             local_files_only=local_files_only,
+            revision=revision,
             provider=provider,
             session_options=session_options,
             provider_options=provider_options,
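The `_export ^ export` check above only fires when the caller's `export` flag disagrees with what was detected on disk. A minimal pure-Python sketch of that decision table (hypothetical resolve_export helper, warnings omitted):

    def resolve_export(export: bool, n_onnx_files: int) -> bool:
        _export = n_onnx_files == 0   # no ONNX file found -> must export
        if _export ^ export:          # detection disagrees with the caller
            if export:
                _export = True        # caller insists on re-exporting
            # else: warn and keep _export=True (export enabled automatically)
        return _export

    assert resolve_export(export=False, n_onnx_files=2) is False  # reuse existing ONNX
    assert resolve_export(export=False, n_onnx_files=0) is True   # auto-export
    assert resolve_export(export=True, n_onnx_files=2) is True    # explicit re-export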
@@ -961,7 +982,6 @@ def _prepare_onnx_outputs(
     @staticmethod
     def _cached_file(
         model_path: Union[Path, str],
-        use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -970,23 +990,18 @@ def _cached_file(
         subfolder: str = "",
         local_files_only: bool = False,
     ):
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
+        if isinstance(model_path, Path):
+            model_path = model_path.as_posix()

-        model_path = Path(model_path)
         # locates a file in a local folder and repo, downloads and cache it if necessary.
-        if model_path.is_dir():
-            model_cache_path = model_path / subfolder / file_name
-            preprocessors = maybe_load_preprocessors(model_path.as_posix())
+        if os.path.isdir(model_path):
+            model_cache_path = os.path.join(model_path, subfolder, file_name)
+            preprocessors = maybe_load_preprocessors(model_path)
         else:
+            model_path = model_path.replace(os.sep, "/")
+            subfolder = str(subfolder).replace(os.sep, "/")
             model_cache_path = hf_hub_download(
-                repo_id=model_path.as_posix(),
+                repo_id=model_path,
                 filename=file_name,
                 subfolder=subfolder,
                 token=token,
@@ -998,7 +1013,7 @@ def _cached_file(
             # try download external data
             try:
                 hf_hub_download(
-                    repo_id=model_path.as_posix(),
+                    repo_id=model_path,
                     subfolder=subfolder,
                     filename=file_name + "_data",
                     token=token,
@@ -1011,10 +1026,9 @@ def _cached_file(
                 # model doesn't use external data
                 pass

-        model_cache_path = Path(model_cache_path)
-        preprocessors = maybe_load_preprocessors(model_path.as_posix(), subfolder=subfolder)
+            preprocessors = maybe_load_preprocessors(model_path, subfolder=subfolder)

-        return model_cache_path, preprocessors
+        return Path(model_cache_path), preprocessors

     def can_generate(self) -> bool:
         """
@@ -1124,7 +1138,6 @@ def _export(
         cls,
         model_id: str,
         config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
         token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -1138,15 +1151,6 @@ def _export(
         use_io_binding: Optional[bool] = None,
         task: Optional[str] = None,
     ) -> "ORTModel":
-        if use_auth_token is not None:
-            warnings.warn(
-                "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
-                FutureWarning,
-            )
-            if token is not None:
-                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
-            token = use_auth_token
-
         if task is None:
             task = cls._auto_model_to_task(cls.auto_model_class)
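The net effect of the modeling_ort.py changes is that callers no longer have to pass `export=True` by hand. A hedged usage sketch, assuming the test repos referenced later in this patch stay available:

    from optimum.onnxruntime import ORTModelForCausalLM

    # Repo already containing model.onnx: loaded directly, no export step.
    model = ORTModelForCausalLM.from_pretrained("optimum-internal-testing/tiny-random-llama")

    # Pure PyTorch repo: export is inferred and enabled automatically.
    model = ORTModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    model.save_pretrained("./llama-onnx")  # persist the converted model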
"""Utility functions related to both local files and files on the Hugging Face Hub.""" +import os import re import warnings from pathlib import Path @@ -22,11 +23,6 @@ import huggingface_hub from huggingface_hub import get_hf_file_metadata, hf_hub_url -from ..utils import logging - - -logger = logging.get_logger(__name__) - def validate_file_exists( model_name_or_path: Union[str, Path], filename: str, subfolder: str = "", revision: Optional[str] = None @@ -91,17 +87,13 @@ def find_files_matching_pattern( raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") token = use_auth_token - model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path - pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern) - if model_path.is_dir(): - path = model_path - files = model_path.glob(glob_pattern) + model_path = str(model_name_or_path) if isinstance(model_name_or_path, Path) else model_name_or_path + pattern = re.compile(subfolder + pattern) + if os.path.isdir(model_path): + files = Path(model_path).glob(glob_pattern) files = [p for p in files if re.search(pattern, str(p))] else: - path = model_name_or_path - repo_files = map(Path, huggingface_hub.list_repo_files(model_name_or_path, revision=revision, token=token)) - if subfolder != "": - path = f"{path}/{subfolder}" - files = [Path(p) for p in repo_files if re.match(pattern, str(p))] + repo_files = huggingface_hub.list_repo_files(model_path, revision=revision, token=token) + files = [Path(p) for p in repo_files if re.match(pattern, p)] return files diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py index 749e078456..aa88117898 100644 --- a/tests/onnxruntime/test_diffusion.py +++ b/tests/onnxruntime/test_diffusion.py @@ -113,7 +113,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1): @require_diffusers def test_load_vanilla_model_which_is_not_supported(self): with self.assertRaises(Exception) as context: - _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"]) self.assertIn( f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) @@ -138,10 +138,7 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_num_images_per_prompt(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) for batch_size in [1, 3]: for height in [16, 32]: @@ -375,7 +372,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_ @require_diffusers def test_load_vanilla_model_which_is_not_supported(self): with self.assertRaises(Exception) as context: - _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"]) self.assertIn( f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) @@ -395,10 +392,7 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str): @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_num_images_per_prompt(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - 
diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py
index 749e078456..aa88117898 100644
--- a/tests/onnxruntime/test_diffusion.py
+++ b/tests/onnxruntime/test_diffusion.py
@@ -113,7 +113,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
     @require_diffusers
     def test_load_vanilla_model_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True)
+            _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"])

         self.assertIn(
             f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception)
@@ -138,10 +138,7 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
     def test_num_images_per_prompt(self, model_arch: str):
-        model_args = {"test_name": model_arch, "model_arch": model_arch}
-        self._setup(model_args)
-
-        pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
+        pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

         for batch_size in [1, 3]:
             for height in [16, 32]:
@@ -375,7 +372,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
     @require_diffusers
     def test_load_vanilla_model_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True)
+            _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"])

         self.assertIn(
             f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception)
@@ -395,10 +392,7 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
     def test_num_images_per_prompt(self, model_arch: str):
-        model_args = {"test_name": model_arch, "model_arch": model_arch}
-        self._setup(model_args)
-
-        pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
+        pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

         for batch_size in [1, 3]:
             for height in [16, 32]:
@@ -617,7 +611,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
     @require_diffusers
     def test_load_vanilla_model_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True)
+            _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"])

         self.assertIn(
             f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception)
@@ -637,10 +631,7 @@ def test_ort_pipeline_class_dispatch(self, model_arch: str):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @require_diffusers
     def test_num_images_per_prompt(self, model_arch: str):
-        model_args = {"test_name": model_arch, "model_arch": model_arch}
-        self._setup(model_args)
-
-        pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch])
+        pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch])

         for batch_size in [1, 3]:
             for height in [16, 32]:
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 9ea0483e35..4dd7f91099 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -141,6 +141,44 @@ def __init__(self, *args, **kwargs):
         self.TINY_ONNX_SEQ2SEQ_MODEL_ID = "fxmarty/sshleifer-tiny-mbart-onnx"
         self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID = "optimum-internal-testing/tiny-stable-diffusion-onnx"

+    @parameterized.expand((ORTModelForCausalLM, ORTModel))
+    def test_load_model_from_hub_infer_onnx_model(self, model_cls):
+        model_id = "optimum-internal-testing/tiny-random-llama"
+        file_name = "model_optimized.onnx"
+
+        model = model_cls.from_pretrained(model_id)
+        self.assertEqual(model.model_path.name, "model.onnx")
+
+        model = model_cls.from_pretrained(model_id, revision="onnx")
+        self.assertEqual(model.model_path.name, "model.onnx")
+
+        model = model_cls.from_pretrained(model_id, revision="onnx", file_name=file_name)
+        self.assertEqual(model.model_path.name, file_name)
+
+        model = model_cls.from_pretrained(model_id, revision="merged-onnx", file_name=file_name)
+        self.assertEqual(model.model_path.name, file_name)
+
+        if model_cls is ORTModelForCausalLM:
+            model = model_cls.from_pretrained(model_id, revision="merged-onnx")
+            self.assertEqual(model.model_path.name, "decoder_model_merged.onnx")
+
+            model = model_cls.from_pretrained(self.LOCAL_MODEL_PATH, use_cache=False, use_io_binding=False)
+            self.assertEqual(model.model_path.name, "model.onnx")
+
+        model = model_cls.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder")
+        self.assertEqual(model.model_path.name, "model.onnx")
+
+        model = model_cls.from_pretrained(model_id, revision="merged-onnx", subfolder="subfolder", file_name=file_name)
+        self.assertEqual(model.model_path.name, file_name)
+
+        model = model_cls.from_pretrained(model_id, revision="merged-onnx", file_name="decoder_with_past_model.onnx")
+        self.assertEqual(model.model_path.name, "decoder_with_past_model.onnx")
+
+        model = model_cls.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        self.assertEqual(model.model_path.name, "model.onnx")
+
+        with self.assertRaises(FileNotFoundError):
+            model_cls.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM", file_name="test.onnx")
+
     def test_load_model_from_local_path(self):
         model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH)
         self.assertIsInstance(model.model, onnxruntime.InferenceSession)
@@ -154,7 +192,8 @@ def test_load_model_from_hub(self):
     def test_load_model_from_hub_subfolder(self):
         # does not pass with ORTModel as it does not have export_feature attribute
         model = ORTModelForSequenceClassification.from_pretrained(
-            "fxmarty/tiny-bert-sst2-distilled-subfolder", subfolder="my_subfolder", export=True
+            "fxmarty/tiny-bert-sst2-distilled-subfolder",
+            subfolder="my_subfolder",
         )
         self.assertIsInstance(model.model, onnxruntime.InferenceSession)
         self.assertIsInstance(model.config, PretrainedConfig)
@@ -164,9 +203,7 @@ def test_load_model_from_hub_subfolder(self):
         self.assertIsInstance(model.config, PretrainedConfig)

     def test_load_seq2seq_model_from_hub_subfolder(self):
-        model = ORTModelForSeq2SeqLM.from_pretrained(
-            "fxmarty/tiny-mbart-subfolder", subfolder="my_folder", export=True
-        )
+        model = ORTModelForSeq2SeqLM.from_pretrained("fxmarty/tiny-mbart-subfolder", subfolder="my_folder")
         self.assertIsInstance(model.encoder, ORTEncoder)
         self.assertIsInstance(model.decoder, ORTDecoderForSeq2Seq)
         self.assertIsInstance(model.decoder_with_past, ORTDecoderForSeq2Seq)
@@ -383,8 +420,7 @@ def test_load_stable_diffusion_model_unknown_provider(self):
         )

     def test_load_model_from_hub_without_onnx_model(self):
-        with self.assertRaises(FileNotFoundError):
-            ORTModel.from_pretrained(self.FAIL_ONNX_MODEL_ID)
+        ORTModel.from_pretrained(self.FAIL_ONNX_MODEL_ID)

     def test_model_on_cpu(self):
         model = ORTModel.from_pretrained(self.ONNX_MODEL_ID)
@@ -1317,7 +1353,7 @@ class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForQuestionAnswering.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForQuestionAnswering.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -1519,7 +1555,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForMaskedLM.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForMaskedLM.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -1705,7 +1741,7 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForSequenceClassification.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForSequenceClassification.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -1912,7 +1948,7 @@ class ORTModelForTokenClassificationIntegrationTest(ORTModelTestMixin):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForTokenClassification.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForTokenClassification.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -2433,7 +2469,6 @@ def test_load_model_from_hub_onnx(self):
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
             _ = ORTModelForCausalLM.from_pretrained(MODEL_NAMES["vit"], export=True)
-
         self.assertIn("only supports the tasks", str(context.exception))

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -2894,8 +2929,7 @@ def _get_onnx_model_dir(self, model_id, model_arch, test_name):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForImageClassification.from_pretrained(MODEL_NAMES["t5"], export=True)
-
+            _ = ORTModelForImageClassification.from_pretrained(MODEL_NAMES["t5"])
         self.assertIn("only supports the tasks", str(context.exception))

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -3071,7 +3105,7 @@ class ORTModelForSemanticSegmentationIntegrationTest(ORTModelTestMixin):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForSemanticSegmentation.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForSemanticSegmentation.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -3258,7 +3292,7 @@ def _generate_random_audio_data(self):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForAudioClassification.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForAudioClassification.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -3442,7 +3476,7 @@ def _generate_random_audio_data(self):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForCTC.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForCTC.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -3544,7 +3578,7 @@ def _generate_random_audio_data(self):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForAudioXVector.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForAudioXVector.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -3641,7 +3675,7 @@ def _generate_random_audio_data(self):

     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
-            _ = ORTModelForAudioFrameClassification.from_pretrained(MODEL_NAMES["t5"], export=True)
+            _ = ORTModelForAudioFrameClassification.from_pretrained(MODEL_NAMES["t5"])

         self.assertIn("only supports the tasks", str(context.exception))

@@ -4829,7 +4863,6 @@ def _get_preprocessors(self, model_id):
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
             _ = ORTModelForImageToImage.from_pretrained(MODEL_NAMES["bert"], export=True)
-
         self.assertIn("only supports the tasks", str(context.exception))

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
diff --git a/tests/onnxruntime/test_timm.py b/tests/onnxruntime/test_timm.py
index c51bcc01a0..ddb2d0f297 100644
--- a/tests/onnxruntime/test_timm.py
+++ b/tests/onnxruntime/test_timm.py
@@ -55,7 +55,7 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin):
     @pytest.mark.run_slow
     @slow
     def test_compare_to_timm(self, model_id):
-        onnx_model = ORTModelForImageClassification.from_pretrained(model_id, export=True)
+        onnx_model = ORTModelForImageClassification.from_pretrained(model_id)
         self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession)
         self.assertIsInstance(onnx_model.config, PretrainedConfig)
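Taken together, the tests above pin down the new loading contract: with several ONNX files in a revision, the default "model.onnx" wins unless `file_name` and/or `subfolder` narrow the choice. A short usage sketch (repo and file names are those used in the tests; treat the printed result as illustrative):

    from optimum.onnxruntime import ORTModelForCausalLM

    model = ORTModelForCausalLM.from_pretrained(
        "optimum-internal-testing/tiny-random-llama",
        revision="merged-onnx",
        file_name="model_optimized.onnx",
    )
    print(model.model_path.name)  # model_optimized.onnx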