From 524b6682b579e963a98680199dc8431805e0dfb8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 25 Sep 2023 14:24:12 +0200 Subject: [PATCH] add cached file --- optimum/onnxruntime/modeling_decoder.py | 50 +++-------- optimum/onnxruntime/modeling_ort.py | 110 +++++++++++++++--------- 2 files changed, 78 insertions(+), 82 deletions(-) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index e130be0bca6..b13b7ed7d63 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -1008,51 +1008,23 @@ def _from_pretrained( else: init_cls = ORTModelForCausalLM - ################################################################################################## - - preprocessors = None - if model_path.is_dir(): - model_cache_path = model_path / file_name - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass - model_cache_path = Path(model_cache_path) - new_model_save_dir = model_cache_path.parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + model_cache_path, preprocessors = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + new_model_save_dir = model_cache_path.parent # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it # instead of the path only. if model_save_dir is None: model_save_dir = new_model_save_dir - ################################################################################################## - # Since v1.7.0 decoder with past models have fixed sequence length of 1 # To keep these models compatible we set this dimension to dynamic onnx_model = onnx.load(str(model_cache_path), load_external_data=False) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index b190432fe6d..b58a37eb43a 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -486,55 +486,30 @@ def _from_pretrained( "not behave as expected." ) - preprocessors = None - if model_path.is_dir(): - model = ORTModel.load_model( - model_path / file_name, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass - - model = ORTModel.load_model( - model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options - ) - new_model_save_dir = Path(model_cache_path).parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + model_cache_path, preprocessors = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + new_model_save_dir = model_cache_path.parent # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it # instead of the path only. if model_save_dir is None: model_save_dir = new_model_save_dir + model = ORTModel.load_model( + model_cache_path, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + return cls( model=model, config=config, @@ -828,6 +803,55 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool): " with model.use_io_binding = False, or pass torch.Tensor inputs instead." ) + @staticmethod + def _cached_file( + model_path: Union[Path, str], + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + ): + model_path = Path(model_path) + + # locates a file in a local folder and repo, downloads and cache it if necessary. + if model_path.is_dir(): + model_cache_path = model_path / file_name + preprocessors = maybe_load_preprocessors(model_path.as_posix()) + else: + model_cache_path = hf_hub_download( + repo_id=model_path.as_posix(), + filename=file_name, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + # try download external data + try: + hf_hub_download( + repo_id=model_path.as_posix(), + subfolder=subfolder, + filename=file_name + "_data", + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + # model doesn't use external data + pass + + model_cache_path = Path(model_cache_path) + preprocessors = maybe_load_preprocessors(model_path.as_posix(), subfolder=subfolder) + + return model_cache_path, preprocessors + FEATURE_EXTRACTION_EXAMPLE = r""" Example of feature extraction: