def _parse_default_security(raw):
    """Translate the ``DEFAULT_SECURITY`` environment value into a ``trust_repo`` value.

    Args:
        raw (str or None): the raw environment value, or ``None`` when the
            variable is not set.

    Returns:
        ``True``, ``False``, ``None`` or ``"check"`` for the recognised
        spellings ``"True"``, ``"False"``, ``"None"`` and ``"check"``. An unset
        variable maps to ``None``. Any other string is returned unchanged after
        a warning, so that existing configurations keep their behaviour.
    """
    if raw is None:
        return None
    known_values = {"True": True, "False": False, "None": None, "check": "check"}
    if raw in known_values:
        return known_values[raw]
    warnings.warn(
        f"Unrecognized DEFAULT_SECURITY value {raw!r}; "
        "expected one of 'True', 'False', 'None' or 'check'.")
    return raw


# Read with ``.get`` so that importing this module does not raise KeyError
# when the DEFAULT_SECURITY environment variable is not set.
_DEFAULT_SECURITY = _parse_default_security(os.environ.get("DEFAULT_SECURITY"))
def _remove_if_exists(path):
    """Delete ``path`` if it exists: a regular file is unlinked, anything else is removed as a tree.

    Args:
        path (string): filesystem path to delete. A path that does not exist
            is silently ignored.
    """
    if not os.path.exists(path):
        return
    if os.path.isfile(path):
        os.remove(path)
        return
    shutil.rmtree(path)
def _validate_not_a_forked_repo(repo_owner, repo_name, branch):
    """Check that ``branch`` names a branch or tag of ``repo_owner/repo_name``.

    The GitHub REST API is queried through ``urlopen`` so no local git is
    needed. Branches are listed first, then tags, 100 entries per page, until
    a page comes back empty.

    Args:
        repo_owner (string): owner part of the GitHub repository.
        repo_name (string): name part of the GitHub repository.
        branch (string): branch name, tag name, or a prefix of the commit sha
            that a branch/tag points to.

    Raises:
        ValueError: if no listed branch or tag matches ``branch``.
    """
    request_headers = {'Accept': 'application/vnd.github.v3+json'}
    github_token = os.environ.get(ENV_GITHUB_TOKEN)
    if github_token is not None:
        request_headers['Authorization'] = f'token {github_token}'

    api_root = f'https://api.github.com/repos/{repo_owner}/{repo_name}'
    for endpoint in (f'{api_root}/branches', f'{api_root}/tags'):
        page_number = 1
        while True:
            listing_url = f'{endpoint}?per_page=100&page={page_number}'
            refs = json.loads(_read_url(Request(listing_url, headers=request_headers)))
            # An empty page means this endpoint has no more entries
            if not refs:
                break
            if any(ref['name'] == branch or ref['commit']['sha'].startswith(branch) for ref in refs):
                return
            page_number += 1

    raise ValueError(f'Cannot find {branch} in https://github.com/{repo_owner}/{repo_name}. '
                     'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
    """Enforce the ``trust_repo`` policy for ``repo_owner/repo_name``.

    A repo counts as trusted when ``<owner>_<name>`` is listed in the
    ``trusted_list`` file inside the hub directory, when a directory named
    ``owner_name_branch`` already exists in the hub directory, or when the
    owner is one of ``_TRUSTED_REPO_OWNERS``.

    Args:
        repo_owner (string): owner part of the GitHub repository.
        repo_name (string): name part of the GitHub repository.
        owner_name_branch (string): name of the cache folder for this
            repo/branch, used for the legacy "already downloaded" check.
        trust_repo (bool, string or None): ``True`` trusts without asking,
            ``False`` always prompts, ``"check"`` prompts only for untrusted
            repos, ``None`` only warns for untrusted repos.
        calling_fn (string): name of the public function shown in the
            warning text. Default is ``"load"``.

    Raises:
        Exception: if the user answers the prompt with "n", "no" or an empty
            line.
        ValueError: if the answer to the prompt is not recognised.
    """
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")

    if not os.path.exists(filepath):
        Path(filepath).touch()
    with open(filepath, 'r') as file:
        trusted_repos = tuple(line.strip() for line in file)

    # To minimize friction of introducing the new trust_repo mechanism, we consider that
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = '_'.join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
        or repo_owner in _TRUSTED_REPO_OWNERS
    )

    # TODO: Remove `None` option in 1.14 and change the default to "check"
    if trust_repo is None:
        if not is_trusted:
            # Every fragment that interpolates calling_fn must be an f-string;
            # otherwise the literal text "{calling_fn}" is shown to the user.
            warnings.warn(
                "You are about to download and run code from an untrusted repository. In a future release, this won't "
                f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
                f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
        answer = response.lower()
        if answer in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
        elif answer in ("n", "no", ""):
            raise Exception("Untrusted repository.")
        else:
            raise ValueError(f"Unrecognized response {response}.")

    # At this point we're sure that the user trusts the repo (or wants to trust it)
    if not is_trusted:
        with open(filepath, "a") as file:
            file.write(owner_name + "\n")
def set_dir(d):
    r"""
    Optionally set the Torch Hub directory used to save downloaded models & weights.

    A leading ``~`` is expanded to the user's home directory, matching how the
    default cache location and download destinations are resolved elsewhere in
    this module.

    Args:
        d (string): path to a local folder to save downloaded models & weights.
    """
    global _hub_dir
    # Without expansion, "~/x" would later be created as a literal "~" directory.
    _hub_dir = os.path.expanduser(d)
def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Show the docstring of entrypoint ``model``.

    Args:
        github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. If ``tag_name`` is not specified, the default branch is assumed to be ``main`` if
            it exists, and otherwise ``master``.
            Example: 'pytorch/vision:0.10'
        model (string): a string of entrypoint name defined in repo's ``hubconf.py``
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
            Default is ``False``.
        skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
            specified by the ``github`` argument properly belongs to the repo owner. This will make
            requests to the GitHub API; you can specify a non-default GitHub token by setting the
            ``GITHUB_TOKEN`` environment variable. Default is ``False``.
        trust_repo (bool, string or None): ``"check"``, ``True``, ``False`` or ``None``.
            This parameter helps ensuring that users only run code from repos that they trust.

            - If ``False``, a prompt will ask the user whether the repo should
              be trusted.
            - If ``True``, the repo will be added to the trusted list and loaded
              without requiring explicit confirmation.
            - If ``"check"``, the repo will be checked against the list of
              trusted repos in the cache. If it is not present in that list, the
              behaviour will fall back onto the ``trust_repo=False`` option.
            - If ``None``: this will raise a warning, inviting the user to set
              ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This
              is only present for backward compatibility and will be removed in
              v1.14.

            Default is ``None`` and will eventually change to ``"check"`` in a future version.

    Returns:
        The ``__doc__`` attribute of the entrypoint ``model``.

    Example:
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
                                    skip_validation=skip_validation)

    sys.path.insert(0, repo_dir)
    try:
        hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
    finally:
        # Restore sys.path even when importing hubconf.py raises, so a broken
        # repo does not leave its directory at the front of the import path.
        sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__
+ + If ``source`` is 'github', ``repo_or_dir`` is expected to be + of the form ``repo_owner/repo_name[:tag_name]`` with an optional + tag/branch. + + If ``source`` is 'local', ``repo_or_dir`` is expected to be a + path to a local directory. + + Args: + repo_or_dir (string): If ``source`` is 'github', + this should correspond to a github repo with format ``repo_owner/repo_name[:tag_name]`` with + an optional tag/branch, for example 'pytorch/vision:0.10'. If ``tag_name`` is not specified, + the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. + If ``source`` is 'local' then it should be a path to a local directory. + model (string): the name of a callable (entrypoint) defined in the + repo/dir's ``hubconf.py``. + *args (optional): the corresponding args for callable ``model``. + source (string, optional): 'github' or 'local'. Specifies how + ``repo_or_dir`` is to be interpreted. Default is 'github'. + trust_repo (bool, string or None): ``"check"``, ``True``, ``False`` or ``None``. + This parameter helps ensuring that users only run code from repos that they trust. + + - If ``False``, a prompt will ask the user whether the repo should + be trusted. + - If ``True``, the repo will be added to the trusted list and loaded + without requiring explicit confirmation. + - If ``"check"``, the repo will be checked against the list of + trusted repos in the cache. If it is not present in that list, the + behaviour will fall back onto the ``trust_repo=False`` option. + - If ``None``: this will raise a warning, inviting the user to set + ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This + is only present for backward compatibility and will be removed in + v1.14. + + Default is ``None`` and will eventually change to ``"check"`` in a future version. + force_reload (bool, optional): whether to force a fresh download of + the github repo unconditionally. Does not have any effect if + ``source = 'local'``. Default is ``False``. 
def _load_local(hubconf_dir, model, *args, **kwargs):
    r"""
    Load a model from a local directory with a ``hubconf.py``.

    ``hubconf_dir`` is placed at the front of ``sys.path`` while ``hubconf.py``
    is imported and the entrypoint is called, and is removed again afterwards,
    including when either step raises.

    Args:
        hubconf_dir (string): path to a local directory that contains a
            ``hubconf.py``.
        model (string): name of an entrypoint defined in the directory's
            ``hubconf.py``.
        *args (optional): the corresponding args for callable ``model``.
        **kwargs (optional): the corresponding kwargs for callable ``model``.

    Returns:
        The output of the ``model`` callable when called with the given
        ``*args`` and ``**kwargs``.

    Example:
        >>> path = '/some/local/path/pytorch/vision'
        >>> model = _load_local(path, 'resnet50', pretrained=True)
    """
    sys.path.insert(0, hubconf_dir)
    try:
        hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

        entry = _load_entry_from_hubconf(hub_module, model)
        model = entry(*args, **kwargs)
    finally:
        # Undo the sys.path change even if the import or the entrypoint fails.
        sys.path.remove(hubconf_dir)

    return model
# Hub used to support automatically extracts from zipfile manually compressed by users.
# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# We should remove this support since zipfile is now default zipfile format for torch.save().
def _is_legacy_zip_format(filename):
    """Return ``True`` if ``filename`` is a zip archive holding exactly one non-directory entry.

    Args:
        filename (string): path of the file to inspect.

    Returns:
        bool: ``False`` when the file is not a zip archive, or when the archive
        does not consist of a single file entry.
    """
    if not zipfile.is_zipfile(filename):
        return False
    # Open the archive in a ``with`` block so the file handle is always closed.
    with zipfile.ZipFile(filename) as archive:
        infolist = archive.infolist()
    return len(infolist) == 1 and not infolist[0].is_dir()
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned.
    The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
    ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set.

    Returns:
        The object produced by ``torch.load`` for the cached (or freshly
        downloaded) file.

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    # exist_ok=True tolerates an already existing directory; any other
    # OSError still propagates to the caller.
    os.makedirs(model_dir, exist_ok=True)

    if file_name is not None:
        filename = file_name
    else:
        filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # NOTE(review): torch.load unpickles the downloaded file; only use URLs
    # from sources you trust.
    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location)
    return torch.load(cached_file, map_location=map_location)