Skip to content

Commit

Permalink
Fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 10, 2024
1 parent dab5dd6 commit 320c9fa
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
2 changes: 2 additions & 0 deletions keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import jax_memory_cleanup
from keras_nlp.utils.preset_utils import list_presets
Expand Down Expand Up @@ -218,6 +219,7 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
save_serialized_object(self, preset, config_file=CONFIG_FILE)
save_weights(self, preset, MODEL_WEIGHTS_FILE)
Expand Down
5 changes: 3 additions & 2 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from keras_nlp.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.models import Tokenizer
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
Expand Down Expand Up @@ -150,7 +150,7 @@ def from_preset(
preset,
config_file=TOKENIZER_CONFIG_FILE,
)
if tokenizer_preset_cls is not cls:
if tokenizer_preset_cls is not cls.tokenizer_cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
Expand Down Expand Up @@ -192,6 +192,7 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
self.tokenizer.save_to_preset(preset)
save_serialized_object(
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
Expand Down Expand Up @@ -345,6 +346,7 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
if self.preprocessor is None:
raise ValueError(
Expand Down
5 changes: 2 additions & 3 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_tokenizer
Expand Down Expand Up @@ -143,6 +142,7 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
save_tokenizer_assets(self, preset)
save_serialized_object(self, preset, config_file=TOKENIZER_CONFIG_FILE)
Expand Down Expand Up @@ -239,4 +239,3 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
tokenizer.load_assets(asset_dir)
5 changes: 4 additions & 1 deletion keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,16 @@ def recursive_pop(config, key):
recursive_pop(value, key)


def make_preset_dir(preset):
def check_keras_version():
if not backend_config.keras_3():
raise ValueError(
"`save_to_preset` requires Keras 3. Run `pip install -U keras` "
"upgrade your Keras version, or see https://keras.io/getting_started/ "
"for more info on Keras versions and installation."
)


def make_preset_dir(preset):
os.makedirs(preset, exist_ok=True)


Expand Down

0 comments on commit 320c9fa

Please sign in to comment.