Skip to content

Commit

Permalink
add cached file
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Sep 25, 2023
1 parent c13a170 commit 524b668
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 82 deletions.
50 changes: 11 additions & 39 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,51 +1008,23 @@ def _from_pretrained(
else:
init_cls = ORTModelForCausalLM

##################################################################################################

preprocessors = None
if model_path.is_dir():
model_cache_path = model_path / file_name
new_model_save_dir = model_path
preprocessors = maybe_load_preprocessors(model_id)
else:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)

# try download external data
try:
hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=file_name + "_data",
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass
model_cache_path = Path(model_cache_path)
new_model_save_dir = model_cache_path.parent
preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
model_cache_path, preprocessors = cls._cached_file(
model_path=model_path,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
file_name=file_name,
subfolder=subfolder,
local_files_only=local_files_only,
)
new_model_save_dir = model_cache_path.parent

# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
# instead of the path only.
if model_save_dir is None:
model_save_dir = new_model_save_dir

##################################################################################################

# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic
onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
Expand Down
110 changes: 67 additions & 43 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,55 +486,30 @@ def _from_pretrained(
"not behave as expected."
)

preprocessors = None
if model_path.is_dir():
model = ORTModel.load_model(
model_path / file_name,
provider=provider,
session_options=session_options,
provider_options=provider_options,
)
new_model_save_dir = model_path
preprocessors = maybe_load_preprocessors(model_id)
else:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)

# try download external data
try:
hf_hub_download(
repo_id=model_id,
subfolder=subfolder,
filename=file_name + "_data",
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EntryNotFoundError:
# model doesn't use external data
pass

model = ORTModel.load_model(
model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options
)
new_model_save_dir = Path(model_cache_path).parent
preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
model_cache_path, preprocessors = cls._cached_file(
model_path=model_path,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
file_name=file_name,
subfolder=subfolder,
local_files_only=local_files_only,
)
new_model_save_dir = model_cache_path.parent

# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
# instead of the path only.
if model_save_dir is None:
model_save_dir = new_model_save_dir

model = ORTModel.load_model(
model_cache_path,
provider=provider,
session_options=session_options,
provider_options=provider_options,
)

return cls(
model=model,
config=config,
Expand Down Expand Up @@ -828,6 +803,55 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool):
" with model.use_io_binding = False, or pass torch.Tensor inputs instead."
)

@staticmethod
def _cached_file(
    model_path: Union[Path, str],
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    file_name: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
):
    """
    Resolve the model file either from a local directory or through `hf_hub_download`.

    When `model_path` is an existing directory, the returned path is simply `model_path / file_name`
    and nothing is downloaded. Otherwise `model_path` is used as the `repo_id` passed to
    `hf_hub_download`, and a second download is attempted for `file_name + "_data"` (external data).

    Args:
        model_path (`Union[Path, str]`):
            Local directory, or the identifier forwarded to `hf_hub_download` as `repo_id`.
        use_auth_token (`Optional[Union[bool, str]]`, defaults to `None`):
            Forwarded unchanged to `hf_hub_download`.
        revision (`Optional[str]`, defaults to `None`):
            Forwarded unchanged to `hf_hub_download`.
        force_download (`bool`, defaults to `False`):
            Forwarded unchanged to `hf_hub_download`.
        cache_dir (`Optional[str]`, defaults to `None`):
            Forwarded unchanged to `hf_hub_download`.
        file_name (`Optional[str]`, defaults to `None`):
            Name of the model file to locate.
        subfolder (`str`, defaults to `""`):
            Forwarded to `hf_hub_download` and to `maybe_load_preprocessors` in the non-directory case only.
        local_files_only (`bool`, defaults to `False`):
            Forwarded unchanged to `hf_hub_download`.

    Returns:
        A tuple `(model_cache_path, preprocessors)` where `model_cache_path` is a `Path` and
        `preprocessors` is whatever `maybe_load_preprocessors` returned.
    """
    model_path = Path(model_path)
    location = model_path.as_posix()

    # Local directory: the file is taken as-is, no download is attempted.
    if model_path.is_dir():
        return model_path / file_name, maybe_load_preprocessors(location)

    # Keyword arguments shared by the model download and the external data download.
    hub_kwargs = {
        "repo_id": location,
        "subfolder": subfolder,
        "use_auth_token": use_auth_token,
        "revision": revision,
        "cache_dir": cache_dir,
        "force_download": force_download,
        "local_files_only": local_files_only,
    }
    downloaded_file = hf_hub_download(filename=file_name, **hub_kwargs)

    # The external data file is optional: a missing entry only means the model doesn't use external data.
    try:
        hf_hub_download(filename=file_name + "_data", **hub_kwargs)
    except EntryNotFoundError:
        pass

    return Path(downloaded_file), maybe_load_preprocessors(location, subfolder=subfolder)


FEATURE_EXTRACTION_EXAMPLE = r"""
Example of feature extraction:
Expand Down

0 comments on commit 524b668

Please sign in to comment.