Skip to content

Commit

Permalink
Re-design loading from preset.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 9, 2024
1 parent df88617 commit a01c94b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
10 changes: 8 additions & 2 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from keras_nlp.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.models import Tokenizer
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_file_exists
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -139,7 +141,11 @@ def from_preset(
f"preset should contain a `{PREPROCESSOR_CONFIG_FILE}`"
) # TODO: update error message.
preprocessor = load_serialized_object(preset, PREPROCESSOR_CONFIG_FILE)
preprocessor.tokenizer = Tokenizer.from_preset(preset)
preprocessor.tokenizer = load_tokenizer(
preset,
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)

return preprocessor

Expand Down
15 changes: 5 additions & 10 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.models import Backbone
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import print_msg
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_WEIGHTS_FILE
Expand Down Expand Up @@ -283,7 +282,7 @@ def from_preset(
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = Preprocessor.from_preset(preset)
preprocessor = preprocessor_preset_cls.from_preset(preset)

# Backbone case.
backbone_preset_cls = check_config_class(preset)
Expand Down Expand Up @@ -317,7 +316,7 @@ def from_preset(
config_overrides = {}
if "dtype" in kwargs:
config_overrides["dtype"] = kwargs.pop("dtype")
backbone = Backbone.from_preset(
backbone = backbone_preset_cls.from_preset(
preset,
load_weights=load_weights,
config_overrides=config_overrides,
Expand All @@ -336,14 +335,10 @@ def from_preset(
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
)

task = load_serialized_object(
preset,
TASK_CONFIG_FILE,
config_overrides,
)
task = load_serialized_object(preset, TASK_CONFIG_FILE)
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, CONFIG_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
# TODO: is task.preprocessor None before this assignment?
task.preprocessor = preprocessor
return task
Expand Down
16 changes: 4 additions & 12 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.python_utils import classproperty

Expand Down Expand Up @@ -232,14 +230,8 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
"Please call `from_preset` on a subclass directly."
)

tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE)

# Ensure all the assets exist.
for asset in tokenizer_preset_cls.file_assets:
get_file(preset, asset)
asset_dir = get_asset_dir(
return load_tokenizer(
preset,
TOKENIZER_CONFIG_FILE,
TOKENIZER_ASSET_DIR,
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
tokenizer.load_assets(asset_dir)
15 changes: 15 additions & 0 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,18 @@ def jax_memory_cleanup(layer):
for weight in layer.weights:
if getattr(weight, "_value", None) is not None:
weight._value.delete()


def load_tokenizer(
preset, config_file=TOKENIZER_CONFIG_FILE, asset_dir=TOKENIZER_ASSET_DIR
):
tokenizer = load_serialized_object(preset, config_file)
for asset in tokenizer.file_assets:
get_file(preset, asset)
asset_dir = get_asset_dir(
preset,
config_file,
asset_dir,
)
tokenizer.load_assets(asset_dir)
return tokenizer

0 comments on commit a01c94b

Please sign in to comment.