Skip to content

Commit

Permalink
Move saving logic to the base classes' from_preset.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 10, 2024
1 parent 751fbfa commit dab5dd6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from keras_nlp.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.models import Tokenizer
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
Expand Down
5 changes: 3 additions & 2 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.models import Backbone
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import print_msg
from keras_nlp.utils.pipeline_model import PipelineModel
Expand Down Expand Up @@ -236,7 +237,7 @@ def from_preset(
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = preprocessor_preset_cls.from_preset(preset)
preprocessor = Preprocessor.from_preset(preset)

# Backbone case.
backbone_preset_cls = check_config_class(preset)
Expand Down Expand Up @@ -270,7 +271,7 @@ def from_preset(
config_overrides = {}
if "dtype" in kwargs:
config_overrides["dtype"] = kwargs.pop("dtype")
backbone = backbone_preset_cls.from_preset(
backbone = Backbone.from_preset(
preset,
load_weights=load_weights,
config_overrides=config_overrides,
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_tokenizer
Expand Down Expand Up @@ -237,3 +239,4 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
tokenizer.load_assets(asset_dir)

0 comments on commit dab5dd6

Please sign in to comment.