Skip to content

Commit

Permalink
Raise FileNotFound exception in get_file if the file doesn't exist.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 9, 2024
1 parent 751fa5d commit ed56931
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@

try:
import kagglehub
from kagglehub.exceptions import KaggleApiHTTPError
except ImportError:
kagglehub = None

try:
import huggingface_hub
from huggingface_hub.utils import EntryNotFoundError
from huggingface_hub.utils import HFValidationError
except ImportError:
huggingface_hub = None
Expand Down Expand Up @@ -85,7 +87,7 @@ def list_subclasses(cls):

def get_file(preset, path):
"""Download a preset file in necessary and return the local path."""
# TODO: Through FileNotFoundError when the path doesn't exist.
# TODO: Add tests for FileNotFound exceptions.
if not isinstance(preset, str):
raise ValueError(
f"A preset identifier must be a string. Received: preset={preset}"
Expand All @@ -108,7 +110,17 @@ def get_file(preset, path):
"'kaggle://username/bert/keras/bert_base_en/1' (to specify a "
f"version). Received: preset={preset}"
)
return kagglehub.model_download(kaggle_handle, path)
try:
return kagglehub.model_download(kaggle_handle, path)
except KaggleApiHTTPError as e:
message = str(e)
if message.find("403 Client Error"):
raise FileNotFoundError(
f"`{path}` doesn't exist in preset directory `{preset}`.\n"
)
else:
raise ValueError(message)

elif preset.startswith(GS_PREFIX):
url = os.path.join(preset, path)
url = url.replace(GS_PREFIX, "https://storage.googleapis.com/")
Expand Down Expand Up @@ -138,9 +150,22 @@ def get_file(preset, path):
"should have the form 'hf://{org}/{model}'. For example, "
f"'hf://username/bert_base_en'. Received: preset={preset}."
) from e
except EntryNotFoundError as e:
message = str(e)
if message.find("403 Client Error"):
raise FileNotFoundError(
f"`{path}` doesn't exist in preset directory `{preset}`.\n"
)
else:
raise ValueError(message)
elif os.path.exists(preset):
# Assume a local filepath.
return os.path.join(preset, path)
local_path = os.path.join(preset, path)
if not os.path.exists(local_path):
raise FileNotFoundError(
f"`{path}` doesn't exist in preset directory `{preset}`.\n"
)
return local_path
else:
raise ValueError(
"Unknown preset identifier. A preset must be a one of:\n"
Expand Down

0 comments on commit ed56931

Please sign in to comment.