Skip to content

Commit

Permalink
Improve preprocessor saving and loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 9, 2024
1 parent ed56931 commit 1451d5a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
49 changes: 37 additions & 12 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.preprocessing.preprocessing_layer import (
Expand All @@ -23,11 +20,13 @@
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -144,6 +143,31 @@ def from_preset(
"config `{PREPROCESSOR_CONFIG_FILE}`."
)

tokenizer_preset_cls = check_config_class(
preset,
config_file=TOKENIZER_CONFIG_FILE,
)
if tokenizer_preset_cls is not cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
lambda x: x.tokenizer_cls == tokenizer_preset_cls,
subclasses,
)
)
if len(subclasses) == 0:
raise ValueError(
f"No registered subclass of `{cls.__name__}` can load "
f"a `{tokenizer_preset_cls.__name__}`."
)
if len(subclasses) > 1:
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
raise ValueError(
f"Ambiguous call to `{cls.__name__}.from_preset()`. "
f"Found multiple possible subclasses {names}. "
"Please call `from_preset` on a subclass directly."
)

preprocessor = load_serialized_object(preset, PREPROCESSOR_CONFIG_FILE)
preprocessor.tokenizer = load_tokenizer(
preset,
Expand All @@ -154,14 +178,15 @@ def from_preset(
return preprocessor

def save_to_preset(self, preset):
# TODO: Tokenizer config should be dropped from the preprocessor.json because tokenizer has a tokenizer.json config.
"""TODO: add docstring."""
self.tokenizer.save_to_preset(preset)
"""Save preprocessor to a preset directory.
preprocessor_config_path = os.path.join(
preset, PREPROCESSOR_CONFIG_FILE
Args:
preset: The path to the local model preset directory.
"""
self.tokenizer.save_to_preset(preset)
save_serialized_object(
self,
preset,
config_file=PREPROCESSOR_CONFIG_FILE,
config_to_skip=["tokenizer"],
)
preprocessor_config = keras.saving.serialize_keras_object(self)
with open(preprocessor_config_path, "w") as config_file:
config_file.write(json.dumps(preprocessor_config, indent=4))
# TODO: there is overlap between tokenizer.json and preprocessor.json.
9 changes: 3 additions & 6 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,21 +206,18 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
f"a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{preset_cls.__name__}` instead."
)
tokenizer_preset_cls = check_config_class(
preset, config_file=TOKENIZER_CONFIG_FILE
)
if tokenizer_preset_cls is not cls:
if preset_cls is not cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
lambda x: x.tokenizer_cls == tokenizer_preset_cls,
lambda x: x.tokenizer_cls == preset_cls,
subclasses,
)
)
if len(subclasses) == 0:
raise ValueError(
f"No registered subclass of `{cls.__name__}` can load "
f"a `{tokenizer_preset_cls.__name__}`."
f"a `{preset_cls.__name__}`."
)
if len(subclasses) > 1:
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
Expand Down
17 changes: 14 additions & 3 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,17 @@ def load_serialized_object(preset, config_file, config_overrides={}):
return keras.saving.deserialize_keras_object(config)


def save_serialized_object(
    layer, preset, config_file=CONFIG_FILE, config_to_skip=None
):
    """Serialize `layer` to a JSON config file inside a preset directory.

    Args:
        layer: The object to serialize with
            `keras.saving.serialize_keras_object`.
        preset: Path of the preset directory the config file is written to.
        config_file: Name of the config file, joined onto `preset`.
            Defaults to `CONFIG_FILE`.
        config_to_skip: Optional iterable of config keys to strip from the
            serialized config before it is written. Defaults to `None`,
            meaning nothing is stripped.
    """
    config_path = os.path.join(preset, config_file)
    config = keras.saving.serialize_keras_object(layer)
    # `None` stands in for "no keys" so the signature does not carry a shared
    # mutable default list.
    for key in config_to_skip or ():
        # NOTE(review): `recursive_pop` is defined elsewhere in this module;
        # presumably it removes `key` at every nesting level -- confirm.
        recursive_pop(config, key)
    # The file handle gets its own name so the `config_file` parameter is not
    # shadowed inside the `with` block.
    with open(config_path, "w") as f:
        f.write(json.dumps(config, indent=4))


def check_config_class(
preset,
config_file=CONFIG_FILE,
Expand Down Expand Up @@ -435,11 +446,11 @@ def load_tokenizer(
):
tokenizer = load_serialized_object(preset, config_file)
for asset in tokenizer.file_assets:
get_file(preset, asset)
asset_dir = get_asset_dir(
get_file(preset, os.path.join(asset_dir, asset))
tokenizer_asset_dir = get_asset_dir(
preset,
config_file,
asset_dir,
)
tokenizer.load_assets(asset_dir)
tokenizer.load_assets(tokenizer_asset_dir)
return tokenizer

0 comments on commit 1451d5a

Please sign in to comment.