From ed56931cda3fd55297e354e56a7f87a3390fe254 Mon Sep 17 00:00:00 2001
From: Samaneh Saadat
Date: Tue, 9 Apr 2024 22:50:17 +0000
Subject: [PATCH] Raise FileNotFound exception in get_file if the file doesn't exist.

---
Review notes (this section is ignored when the patch is applied):
- `str.find` returns -1 (truthy) when the substring is absent and 0 (falsy)
  when the message starts with it, so `if message.find(...)` was inverted.
  The Kaggle handler now uses a membership test instead.
- `EntryNotFoundError` is mapped straight to `FileNotFoundError`; the copied
  "403 Client Error" check did not apply to it.
- Re-raised exceptions are chained with `from e` to keep the original cause.
- Hunk headers were recomputed by hand; the blob hashes on the `index` line
  are those of the original patch and were not regenerated.

 keras_nlp/utils/preset_utils.py | 31 ++++++++++++++++++++++++++++---
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py
index b77ea89cc0..5d9005fe24 100644
--- a/keras_nlp/utils/preset_utils.py
+++ b/keras_nlp/utils/preset_utils.py
@@ -26,11 +26,13 @@
 
 try:
     import kagglehub
+    from kagglehub.exceptions import KaggleApiHTTPError
 except ImportError:
     kagglehub = None
 
 try:
     import huggingface_hub
+    from huggingface_hub.utils import EntryNotFoundError
     from huggingface_hub.utils import HFValidationError
 except ImportError:
     huggingface_hub = None
@@ -85,7 +87,7 @@ def list_subclasses(cls):
 
 def get_file(preset, path):
     """Download a preset file in necessary and return the local path."""
-    # TODO: Through FileNotFoundError when the path doesn't exist.
+    # TODO: Add tests for FileNotFound exceptions.
     if not isinstance(preset, str):
         raise ValueError(
             f"A preset identifier must be a string. Received: preset={preset}"
@@ -108,7 +110,19 @@ def get_file(preset, path):
                 "'kaggle://username/bert/keras/bert_base_en/1' (to specify a "
                 f"version). Received: preset={preset}"
             )
-        return kagglehub.model_download(kaggle_handle, path)
+        try:
+            return kagglehub.model_download(kaggle_handle, path)
+        except KaggleApiHTTPError as e:
+            message = str(e)
+            # A "403 Client Error" in the message is treated as a missing
+            # file; any other HTTP failure is surfaced as a ValueError.
+            if "403 Client Error" in message:
+                raise FileNotFoundError(
+                    f"`{path}` doesn't exist in preset directory `{preset}`.\n"
+                ) from e
+            else:
+                raise ValueError(message) from e
+
     elif preset.startswith(GS_PREFIX):
         url = os.path.join(preset, path)
         url = url.replace(GS_PREFIX, "https://storage.googleapis.com/")
@@ -138,9 +152,20 @@ def get_file(preset, path):
                 "should have the form 'hf://{org}/{model}'. For example, "
                 f"'hf://username/bert_base_en'. Received: preset={preset}."
             ) from e
+        except EntryNotFoundError as e:
+            # EntryNotFoundError already signals that the requested file is
+            # missing, so the message does not need to be inspected.
+            raise FileNotFoundError(
+                f"`{path}` doesn't exist in preset directory `{preset}`.\n"
+            ) from e
     elif os.path.exists(preset):
         # Assume a local filepath.
-        return os.path.join(preset, path)
+        local_path = os.path.join(preset, path)
+        if not os.path.exists(local_path):
+            raise FileNotFoundError(
+                f"`{path}` doesn't exist in preset directory `{preset}`.\n"
+            )
+        return local_path
     else:
         raise ValueError(
             "Unknown preset identifier. A preset must be a one of:\n"