Skip to content

Commit

Permalink
Move saving logic to the base classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 10, 2024
1 parent 740d542 commit 3a500ae
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 69 deletions.
12 changes: 10 additions & 2 deletions keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import jax_memory_cleanup
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_metadata
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_weights
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -214,7 +218,11 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
save_to_preset(self, preset)
make_preset_dir(preset)
save_serialized_object(self, preset, config_file=CONFIG_FILE)
save_weights(self, preset, MODEL_WEIGHTS_FILE)
save_metadata(self, preset)
# save_to_preset(self, preset)

def enable_lora(self, rank):
"""Enable Lora on the backbone.
Expand Down
13 changes: 11 additions & 2 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_config
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.python_utils import classproperty

Expand Down Expand Up @@ -167,8 +169,14 @@ def from_preset(
f"Found multiple possible subclasses {names}. "
"Please call `from_preset` on a subclass directly."
)

preprocessor = load_serialized_object(preset, PREPROCESSOR_CONFIG_FILE)
tokenizer_config = load_config(preset, TOKENIZER_CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"tokenizer": tokenizer_config}
preprocessor = load_serialized_object(
preset,
PREPROCESSOR_CONFIG_FILE,
config_overrides=config_overrides,
)
preprocessor.tokenizer = load_tokenizer(
preset,
config_file=TOKENIZER_CONFIG_FILE,
Expand All @@ -183,6 +191,7 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
make_preset_dir(preset)
self.tokenizer.save_to_preset(preset)
save_serialized_object(
self,
Expand Down
25 changes: 19 additions & 6 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import print_msg
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE
Expand All @@ -30,8 +31,11 @@
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_config
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_weights
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -284,8 +288,14 @@ def from_preset(
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
)

task = load_serialized_object(preset, TASK_CONFIG_FILE)
backbone_config = load_config(preset, CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"backbone": backbone_config}
task = load_serialized_object(
preset,
TASK_CONFIG_FILE,
config_overrides=config_overrides,
)
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
Expand Down Expand Up @@ -334,19 +344,22 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
make_preset_dir(preset)
if self.preprocessor is None:
raise ValueError(
"Cannot save `task` to preset: `Preprocessor` is not initialized."
)

self.preprocessor.save_to_preset(preset)
self.backbone.save_to_preset(preset)
save_to_preset(

save_serialized_object(
self,
preset,
config_filename=TASK_CONFIG_FILE,
weights_filename=TASK_WEIGHTS_FILE,
config_file=TASK_CONFIG_FILE,
config_to_skip=["preprocessor", "backbone"],
)
save_weights(self, preset, TASK_WEIGHTS_FILE)

@property
def layers(self):
Expand Down
9 changes: 7 additions & 2 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.preset_utils import save_to_preset
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_tokenizer_assets
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -139,7 +141,10 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE)
make_preset_dir(preset)
save_tokenizer_assets(self, preset)
save_serialized_object(self, preset, config_file=TOKENIZER_CONFIG_FILE)
# save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE)

def call(self, inputs, *args, training=None, **kwargs):
return self.tokenize(inputs, *args, **kwargs)
Expand Down
96 changes: 39 additions & 57 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,68 +201,52 @@ def recursive_pop(config, key):
recursive_pop(value, key)


def save_to_preset(
layer,
preset,
save_weights=True,
config_filename=CONFIG_FILE,
weights_filename=MODEL_WEIGHTS_FILE,
):
"""Save a KerasNLP layer to a preset directory."""
if not backend_config.keras_3():
raise ValueError(
"`save_to_preset` requires Keras 3. Run `pip install -U keras` "
"upgrade your Keras version, or see https://keras.io/getting_started/ "
"for more info on Keras versions and installation."
)
def make_preset_dir(preset):
os.makedirs(preset, exist_ok=True)

# Save tokenizers assets.
tokenizer = get_tokenizer(layer)
assets = []

def save_tokenizer_assets(tokenizer, preset):
if tokenizer:
asset_dir = os.path.join(preset, TOKENIZER_ASSET_DIR)
os.makedirs(asset_dir, exist_ok=True)
tokenizer.save_assets(asset_dir)
for asset_path in os.listdir(asset_dir):
assets.append(os.path.join(TOKENIZER_ASSET_DIR, asset_path))

# Optionally save weights.
save_weights = save_weights and hasattr(layer, "save_weights")
if save_weights:
weights_path = os.path.join(preset, weights_filename)
layer.save_weights(weights_path)

# Save a serialized Keras object.
config_path = os.path.join(preset, config_filename)
def save_serialized_object(
layer,
preset,
config_file=CONFIG_FILE,
config_to_skip=[],
):
config_path = os.path.join(preset, config_file)
config = keras.saving.serialize_keras_object(layer)
# Include references to weights and assets.
config["assets"] = assets
config["weights"] = weights_filename if save_weights else None
recursive_pop(config, "compile_config")
recursive_pop(config, "build_config")
# Remove preprocessor and backbone from task.json to prevent redundancy in config files.
if config_filename == TASK_CONFIG_FILE:
recursive_pop(config, "preprocessor")
recursive_pop(config, "backbone")
config_to_skip += ["compile_config", "build_config"]
for c in config_to_skip:
recursive_pop(config, c)
with open(config_path, "w") as config_file:
config_file.write(json.dumps(config, indent=4))


def save_metadata(layer, preset):
from keras_nlp import __version__ as keras_nlp_version

keras_version = keras.version() if hasattr(keras, "version") else None
metadata = {
"keras_version": keras_version,
"keras_nlp_version": keras_nlp_version,
"parameter_count": layer.count_params(),
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
}
metadata_path = os.path.join(preset, "metadata.json")
with open(metadata_path, "w") as metadata_file:
metadata_file.write(json.dumps(metadata, indent=4))


# Save any associated metadata.
if config_filename == CONFIG_FILE:
metadata = {
"keras_version": keras_version,
"keras_nlp_version": keras_nlp_version,
"parameter_count": layer.count_params(),
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
}
metadata_path = os.path.join(preset, "metadata.json")
with open(metadata_path, "w") as metadata_file:
metadata_file.write(json.dumps(metadata, indent=4))
def save_weights(layer, preset, weights_file):
if not hasattr(layer, "save_weights"):
raise ValueError(f"`save_weights` hasn't been defined for `{layer}`.")
weights_path = os.path.join(preset, weights_file)
layer.save_weights(weights_path)


def _validate_tokenizer(preset, allow_incomplete=False):
Expand Down Expand Up @@ -399,23 +383,21 @@ def upload_preset(
)


def load_serialized_object(preset, config_file, config_overrides={}):
def load_config(preset, config_file=CONFIG_FILE):
config_path = get_file(preset, config_file)
with open(config_path) as config_file:
config = json.load(config_file)
config["config"] = {**config["config"], **config_overrides}
return keras.saving.deserialize_keras_object(config)
return config


def save_serialized_object(
layer, preset, config_file=CONFIG_FILE, config_to_skip=[]
def load_serialized_object(
preset,
config_file=CONFIG_FILE,
config_overrides={},
):
config_path = os.path.join(preset, config_file)
config = keras.saving.serialize_keras_object(layer)
for c in config_to_skip:
recursive_pop(config, c)
with open(config_path, "w") as config_file:
config_file.write(json.dumps(config, indent=4))
config = load_config(preset, config_file)
config["config"] = {**config["config"], **config_overrides}
return keras.saving.deserialize_keras_object(config)


def check_config_class(
Expand Down

0 comments on commit 3a500ae

Please sign in to comment.