
Commit

Move saving logic to the base classes' from_preset.
SamanehSaadat committed Apr 8, 2024
1 parent f6e7b5a commit df88617
Showing 8 changed files with 77 additions and 128 deletions.
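
For orientation, all of the loading helpers touched below read files out of a single preset directory. The sketch below shows the layout they assume; the concrete file names are the typical values of the preset_utils constants and are an assumption of this note, not something spelled out in the diff.

    # Typical on-disk preset layout (names assumed, e.g. CONFIG_FILE="config.json"):
    #
    #   preset_dir/
    #       config.json         # backbone architecture (CONFIG_FILE)
    #       model.weights.h5    # backbone weights (assumed file name)
    #       tokenizer.json      # tokenizer config (TOKENIZER_CONFIG_FILE)
    #       preprocessor.json   # preprocessor config (PREPROCESSOR_CONFIG_FILE)
    #       task.json           # task config (TASK_CONFIG_FILE)
    #       task.weights.h5     # task head weights (TASK_WEIGHTS_FILE, assumed)
    #       assets/tokenizer/   # vocabulary and merge files (TOKENIZER_ASSET_DIR, assumed)
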
17 changes: 11 additions & 6 deletions keras_nlp/models/backbone.py
@@ -15,10 +15,13 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import jax_memory_cleanup
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.python_utils import classproperty

@@ -197,11 +200,13 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
f"a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{preset_cls.__name__}` instead."
)
return load_from_preset(
preset,
load_weights=load_weights,
config_overrides=kwargs,
)

backbone = load_serialized_object(preset, CONFIG_FILE, config_overrides=kwargs)
if load_weights:
jax_memory_cleanup(backbone)
backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))

return backbone

def save_to_preset(self, preset):
"""Save backbone to a preset directory.
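
With the loading logic now inlined in Backbone.from_preset, a plain call such as the one below goes through load_serialized_object, jax_memory_cleanup, and load_weights directly. A usage sketch; the preset name is illustrative, not taken from this diff.

    import keras_nlp

    # Deserializes the backbone config and then loads its weights file.
    backbone = keras_nlp.models.Backbone.from_preset("bert_base_en")  # illustrative preset

    # Architecture only: skips the JAX memory cleanup and the weight loading.
    backbone_no_weights = keras_nlp.models.Backbone.from_preset(
        "bert_base_en", load_weights=False
    )
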
49 changes: 7 additions & 42 deletions keras_nlp/models/preprocessor.py
@@ -20,9 +20,12 @@
from keras_nlp.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.tokenizers.tokenizer import Tokenizer
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_file_exists
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.python_utils import classproperty


@@ -131,55 +134,17 @@ def from_preset(
"`keras_nlp.models.BertPreprocessor.from_preset()`."
)

# TODO: Move this to load_from_preset in preset_utils.py?
# TODO: Loading a config and deserializing the object has been repeated multiple times. Make it into a function.
# TODO: preprocessor.json can have a tokenizer class so we can load the tokenizer like TokenizerClass.from_preset(preset).
# TODO: Tokenizer config should be dropped from the preprocessor.json because tokenizer has a tokenizer.json config.
preprocessor_config_path = os.path.join(
preset, PREPROCESSOR_CONFIG_FILE
)
if not os.path.exists(preprocessor_config_path):
if not check_file_exists(preset, PREPROCESSOR_CONFIG_FILE):
raise FileNotFoundError(
f"preset should contain a `{PREPROCESSOR_CONFIG_FILE}`"
) # TODO: update error message.
with open(preprocessor_config_path) as config_file:
preprocessor_config = json.load(config_file)

preprocessor = keras.saving.deserialize_keras_object(
preprocessor_config
)

# TODO: Check preprocessor class. Preprocessors don't have preprocessor_cls.
# preprocessor_preset_cls = check_config_class(
# preset, config_file=PREPROCESSOR_CONFIG_FILE
# )
# subclasses = list_subclasses(cls)
# subclasses = tuple(
# filter(
# lambda x: x.preprocessor_cls == preprocessor_preset_cls,
# subclasses,
# )
# )
# if len(subclasses) == 0:
# raise ValueError(
# f"No registered subclass of `{cls.__name__}` can load "
# f"a `{preprocessor_preset_cls.__name__}`."
# )
# if len(subclasses) > 1:
# names = ", ".join(f"`{x.__name__}`" for x in subclasses)
# raise ValueError(
# f"Ambiguous call to `{cls.__name__}.from_preset()`. "
# f"Found multiple possible subclasses {names}. "
# "Please call `from_preset` on a subclass directly."
# )
# preprocessor_cls = subclasses[0]

tokenizer_cls = preprocessor_config["config"]["tokenizer"]
preprocessor.tokenizer = tokenizer_cls.from_preset(preset)
preprocessor = load_serialized_object(preset, PREPROCESSOR_CONFIG_FILE)
preprocessor.tokenizer = Tokenizer.from_preset(preset)

return preprocessor

def save_to_preset(self, preset):
# TODO: Tokenizer config should be dropped from the preprocessor.json because tokenizer has a tokenizer.json config.
"""TODO: add docstring."""
self.tokenizer.save_to_preset(preset)

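
The net effect of these preprocessor changes: preprocessor.json is deserialized through load_serialized_object, and the tokenizer is then re-attached via Tokenizer.from_preset so its assets come in through the tokenizer path. A usage sketch with an illustrative preset name:

    import keras_nlp

    # Builds the preprocessor from preprocessor.json, then loads the tokenizer
    # (and its vocabulary assets) from the same preset. Preset name illustrative.
    preprocessor = keras_nlp.models.BertPreprocessor.from_preset("bert_base_en")
    features = preprocessor(["The quick brown fox."])
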
23 changes: 17 additions & 6 deletions keras_nlp/models/task.py
@@ -19,17 +19,19 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.models import Backbone
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import print_msg
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_task_from_preset
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.python_utils import classproperty

@@ -281,7 +283,7 @@ def from_preset(
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = preprocessor_preset_cls.from_preset(preset)
preprocessor = Preprocessor.from_preset(preset)

# Backbone case.
backbone_preset_cls = check_config_class(preset)
@@ -315,13 +317,15 @@ def from_preset(
config_overrides = {}
if "dtype" in kwargs:
config_overrides["dtype"] = kwargs.pop("dtype")
backbone = backbone_preset_cls.from_preset(
backbone = Backbone.from_preset(
preset,
load_weights=load_weights,
config_overrides=config_overrides,
)
return cls(
backbone=backbone, preprocessor=preprocessor, **kwargs
backbone=backbone,
preprocessor=preprocessor,
**kwargs,
)

# Load task from preset if it exists.
@@ -332,8 +336,15 @@
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
)

task = load_task_from_preset(preset, TASK_CONFIG_FILE)
# TODO: should I avoid duplicating preprocessor memory too?
task = load_serialized_object(
preset,
TASK_CONFIG_FILE,
config_overrides,
)
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
# TODO: is task.preprocessor None before this assignment?
task.preprocessor = preprocessor
return task

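
Task.from_preset now distinguishes two preset flavors: if the preset only ships a backbone config, the backbone is built with Backbone.from_preset and wrapped in a freshly initialized task; if a task config is present, the whole task is deserialized and its task and backbone weights are loaded. A usage sketch; the class, preset name, and keyword argument are illustrative.

    import keras_nlp

    # Backbone-only preset (illustrative name): the classification head is newly
    # initialized and only the backbone weights come from the preset.
    classifier = keras_nlp.models.BertClassifier.from_preset(
        "bert_base_en", num_classes=2
    )

    # A preset that also contains a task config and task weights would instead
    # be deserialized as a whole, with both weight files restored.
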
1 change: 1 addition & 0 deletions keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -294,6 +294,7 @@ def __init__(
self.sequence_length = sequence_length
self.add_prefix_space = add_prefix_space
self.unsplittable_tokens = unsplittable_tokens
self.file_assets = [VOCAB_FILENAME, MERGES_FILENAME]

# Create byte <=> unicode mapping. This is useful for handling
# whitespace tokens.
1 change: 1 addition & 0 deletions keras_nlp/tokenizers/sentence_piece_tokenizer.py
@@ -124,6 +124,7 @@ def __init__(
self.proto = None
self.sequence_length = sequence_length
self.set_proto(proto)
self.file_assets = [VOCAB_FILENAME]

def save_assets(self, dir_path):
path = os.path.join(dir_path, VOCAB_FILENAME)
18 changes: 14 additions & 4 deletions keras_nlp/tokenizers/tokenizer.py
@@ -16,11 +16,14 @@
from keras_nlp.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.python_utils import classproperty

@@ -75,6 +78,7 @@ def detokenize(self, inputs):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.file_assets = None

def tokenize(self, inputs, *args, **kwargs):
"""Transform input tensors of strings into output tokens.
@@ -228,8 +232,14 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
"Please call `from_preset` on a subclass directly."
)

return load_from_preset(
tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE)

# Ensure all the assets exist.
for asset in tokenizer.file_assets:
get_file(preset, asset)
asset_dir = get_asset_dir(
preset,
config_file=TOKENIZER_CONFIG_FILE,
config_overrides=kwargs,
TOKENIZER_CONFIG_FILE,
TOKENIZER_ASSET_DIR,
)
tokenizer.load_assets(asset_dir)
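
The file_assets lists added to the tokenizer subclasses in this commit (byte-pair, sentence-piece, and word-piece) are what the loop above relies on: from_preset fetches every listed asset and then points load_assets at the resulting asset directory. A usage sketch with an illustrative preset name:

    import keras_nlp

    # Deserializes tokenizer.json, downloads each file named in file_assets
    # (for example a vocabulary file), and calls load_assets on the result.
    tokenizer = keras_nlp.models.Tokenizer.from_preset("bert_base_en")  # illustrative preset
    print(tokenizer.tokenize(["The quick brown fox."]))
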
1 change: 1 addition & 0 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -378,6 +378,7 @@ def __init__(
self.special_tokens
)
self.set_vocabulary(vocabulary)
self.file_assets = [VOCAB_FILENAME]

def save_assets(self, dir_path):
path = os.path.join(dir_path, VOCAB_FILENAME)
95 changes: 25 additions & 70 deletions keras_nlp/utils/preset_utils.py
@@ -368,80 +368,12 @@ def upload_preset(
)


def load_from_preset(
preset,
load_weights=True,
config_file=CONFIG_FILE,
config_overrides={},
):
"""Load a KerasNLP layer to a preset directory."""
# Load a serialized Keras object.
def load_serialized_object(preset, config_file, config_overrides={}):
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
config = json.load(config_file)
config["config"] = {**config["config"], **config_overrides}
layer = keras.saving.deserialize_keras_object(config)

# Load any assets for our tokenizers.
tokenizer = get_tokenizer(layer)
if tokenizer and config["assets"]:
for asset in config["assets"]:
get_file(preset, asset)
config_dir = os.path.dirname(config_path)
asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)
tokenizer.load_assets(asset_dir)

# Optionally load weights.
load_weights = load_weights and config["weights"]
if load_weights:
# For jax, delete all previous allocated memory to avoid temporarily
# duplicating variable allocations. torch and tensorflow have stateful
# variable types and do not need this fix.
if backend_config.backend() == "jax":
for weight in layer.weights:
if getattr(weight, "_value", None) is not None:
weight._value.delete()
weights_path = get_file(preset, config["weights"])
layer.load_weights(weights_path)

return layer


def load_task_from_preset(
preset,
load_weights=True,
task_config_file=TASK_CONFIG_FILE,
backbone_config_file=CONFIG_FILE,
config_overrides={},
):
# Load a serialized Keras object.
task_config = _get_config(preset, task_config_file)
task_config["config"] = {**task_config["config"], **config_overrides}
task = keras.saving.deserialize_keras_object(task_config)
backbone_config = _get_config(preset, backbone_config_file)
if load_weights:
if not task_config["weights"]:
raise ValueError(
f"`weights` config is missing from `{task_config_file}` in "
f"preset directory `{preset}`."
)
if not backbone_config["weights"]:
raise ValueError(
f"`weights` config is missing from `{backbone_config_file}` in "
f"preset directory `{preset}`."
)
task_weights_path = os.path.join(preset, task_config["weights"])
task.load_weights(task_weights_path)
backbone_weights_path = os.path.join(preset, backbone_config["weights"])
task.backbone.load_weights(backbone_weights_path)
return task


def _get_config(preset, config_file):
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
config = json.load(config_file)
return config
return keras.saving.deserialize_keras_object(config)


def check_config_class(
@@ -453,3 +385,26 @@ def check_config_class(
with open(config_path) as config_file:
config = json.load(config_file)
return keras.saving.get_registered_object(config["registered_name"])


def get_asset_dir(
preset, config_file=TOKENIZER_CONFIG_FILE, asset_dir=TOKENIZER_ASSET_DIR
):
config_path = get_file(preset, config_file)
config_dir = os.path.dirname(config_path)
return os.path.join(config_dir, asset_dir)


def check_file_exists(preset, config_file):
# TODO: implement this.
return True
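
check_file_exists is still a stub at this point (it unconditionally returns True). A minimal sketch of one possible implementation, assuming get_file raises FileNotFoundError when a file is missing from the preset:

    def check_file_exists(preset, config_file):
        # Treat a failed lookup as "the preset does not contain this file".
        try:
            get_file(preset, config_file)
        except FileNotFoundError:
            return False
        return True
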


def jax_memory_cleanup(layer):
# For jax, delete all previous allocated memory to avoid temporarily
# duplicating variable allocations. torch and tensorflow have stateful
# variable types and do not need this fix.
if backend_config.backend() == "jax":
for weight in layer.weights:
if getattr(weight, "_value", None) is not None:
weight._value.delete()
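
Taken together, the new helpers give every base class the same small recipe: resolve a config with get_file, rebuild the object with load_serialized_object, optionally release JAX buffers with jax_memory_cleanup, and then load weights or tokenizer assets. A sketch of using them directly; the local path and the override key are illustrative.

    from keras_nlp.utils.preset_utils import CONFIG_FILE
    from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
    from keras_nlp.utils.preset_utils import get_asset_dir
    from keras_nlp.utils.preset_utils import jax_memory_cleanup
    from keras_nlp.utils.preset_utils import load_serialized_object

    preset = "./my_preset"  # hypothetical local preset directory

    # Rebuild the backbone from its config, overriding the compute dtype.
    backbone = load_serialized_object(
        preset, CONFIG_FILE, config_overrides={"dtype": "bfloat16"}
    )
    jax_memory_cleanup(backbone)  # no-op on non-JAX backends

    # Resolve the directory that holds tokenizer assets such as vocabularies.
    asset_dir = get_asset_dir(preset, TOKENIZER_CONFIG_FILE)
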
