diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 696e671ed3..962003d0f6 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -16,13 +16,17 @@ from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import CONFIG_FILE +from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_nlp.utils.preset_utils import check_config_class from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import jax_memory_cleanup from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_serialized_object -from keras_nlp.utils.preset_utils import save_to_preset +from keras_nlp.utils.preset_utils import make_preset_dir +from keras_nlp.utils.preset_utils import save_metadata +from keras_nlp.utils.preset_utils import save_serialized_object +from keras_nlp.utils.preset_utils import save_weights from keras_nlp.utils.python_utils import classproperty @@ -214,7 +218,10 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. """ - save_to_preset(self, preset) + make_preset_dir(preset) + save_serialized_object(self, preset, config_file=CONFIG_FILE) + save_weights(self, preset, MODEL_WEIGHTS_FILE) + save_metadata(self, preset) def enable_lora(self, rank): """Enable Lora on the backbone. 
diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index e061b007ad..fc6509954e 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -24,8 +24,10 @@ from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses +from keras_nlp.utils.preset_utils import load_config from keras_nlp.utils.preset_utils import load_serialized_object from keras_nlp.utils.preset_utils import load_tokenizer +from keras_nlp.utils.preset_utils import make_preset_dir from keras_nlp.utils.preset_utils import save_serialized_object from keras_nlp.utils.python_utils import classproperty @@ -167,8 +169,14 @@ def from_preset( f"Found multiple possible subclasses {names}. " "Please call `from_preset` on a subclass directly." ) - - preprocessor = load_serialized_object(preset, PREPROCESSOR_CONFIG_FILE) + tokenizer_config = load_config(preset, TOKENIZER_CONFIG_FILE) + # TODO: rename `config_overrides` — the tokenizer entry here is an addition, not an override. + config_overrides = {"tokenizer": tokenizer_config} + preprocessor = load_serialized_object( + preset, + PREPROCESSOR_CONFIG_FILE, + config_overrides=config_overrides, + ) preprocessor.tokenizer = load_tokenizer( preset, config_file=TOKENIZER_CONFIG_FILE, @@ -183,6 +191,7 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. 
""" + make_preset_dir(preset) self.tokenizer.save_to_preset(preset) save_serialized_object( self, diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index f82c9e03ee..074845b104 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -22,6 +22,7 @@ from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import print_msg from keras_nlp.utils.pipeline_model import PipelineModel +from keras_nlp.utils.preset_utils import CONFIG_FILE from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE @@ -30,8 +31,11 @@ from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses +from keras_nlp.utils.preset_utils import load_config from keras_nlp.utils.preset_utils import load_serialized_object -from keras_nlp.utils.preset_utils import save_to_preset +from keras_nlp.utils.preset_utils import make_preset_dir +from keras_nlp.utils.preset_utils import save_serialized_object +from keras_nlp.utils.preset_utils import save_weights from keras_nlp.utils.python_utils import classproperty @@ -334,8 +338,14 @@ def from_preset( f"which is not a subclass of calling class `{cls.__name__}`. Call " f"`from_preset` directly on `{task_preset_cls.__name__}` instead." ) - - task = load_serialized_object(preset, TASK_CONFIG_FILE) + backbone_config = load_config(preset, CONFIG_FILE) + # TODO: this is not really an override! It's an addition! Should I rename this? 
+ config_overrides = {"backbone": backbone_config} + task = load_serialized_object( + preset, + TASK_CONFIG_FILE, + config_overrides=config_overrides, + ) if load_weights: task.load_weights(get_file(preset, TASK_WEIGHTS_FILE)) task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE)) @@ -384,6 +394,7 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. """ + make_preset_dir(preset) if self.preprocessor is None: raise ValueError( "Cannot save `task` to preset: `Preprocessor` is not initialized." @@ -391,12 +402,14 @@ def save_to_preset(self, preset): self.preprocessor.save_to_preset(preset) self.backbone.save_to_preset(preset) - save_to_preset( + + save_serialized_object( self, preset, - config_filename=TASK_CONFIG_FILE, - weights_filename=TASK_WEIGHTS_FILE, + config_file=TASK_CONFIG_FILE, + config_to_skip=["preprocessor", "backbone"], ) + save_weights(self, preset, TASK_WEIGHTS_FILE) @property def layers(self): diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 4c1e2c6150..55c5f54d77 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -22,7 +22,9 @@ from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_tokenizer -from keras_nlp.utils.preset_utils import save_to_preset +from keras_nlp.utils.preset_utils import make_preset_dir +from keras_nlp.utils.preset_utils import save_serialized_object +from keras_nlp.utils.preset_utils import save_tokenizer_assets from keras_nlp.utils.python_utils import classproperty @@ -139,7 +141,10 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. 
""" - save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE) + make_preset_dir(preset) + save_tokenizer_assets(self, preset) + save_serialized_object(self, preset, config_file=TOKENIZER_CONFIG_FILE) + # save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE) def call(self, inputs, *args, training=None, **kwargs): return self.tokenize(inputs, *args, **kwargs) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index d2d9493577..b3a47cb405 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -201,62 +201,52 @@ def recursive_pop(config, key): recursive_pop(value, key) -def save_to_preset( - layer, - preset, - save_weights=True, - config_filename=CONFIG_FILE, - weights_filename=MODEL_WEIGHTS_FILE, -): - """Save a KerasNLP layer to a preset directory.""" +def make_preset_dir(preset): os.makedirs(preset, exist_ok=True) - # Save tokenizers assets. - tokenizer = get_tokenizer(layer) - assets = [] + +def save_tokenizer_assets(tokenizer, preset): if tokenizer: asset_dir = os.path.join(preset, TOKENIZER_ASSET_DIR) os.makedirs(asset_dir, exist_ok=True) tokenizer.save_assets(asset_dir) - for asset_path in os.listdir(asset_dir): - assets.append(os.path.join(TOKENIZER_ASSET_DIR, asset_path)) - # Optionally save weights. - save_weights = save_weights and hasattr(layer, "save_weights") - if save_weights: - weights_path = os.path.join(preset, weights_filename) - layer.save_weights(weights_path) - # Save a serialized Keras object. - config_path = os.path.join(preset, config_filename) +def save_serialized_object( + layer, + preset, + config_file=CONFIG_FILE, + config_to_skip=[], +): + config_path = os.path.join(preset, config_file) config = keras.saving.serialize_keras_object(layer) - # Include references to weights and assets. 
- config["assets"] = assets - config["weights"] = weights_filename if save_weights else None - recursive_pop(config, "compile_config") - recursive_pop(config, "build_config") - # Remove preprocessor and backbone from task.json to prevent redundancy in config files. - if config_filename == TASK_CONFIG_FILE: - recursive_pop(config, "preprocessor") - recursive_pop(config, "backbone") + config_to_skip += ["compile_config", "build_config"] + for c in config_to_skip: + recursive_pop(config, c) with open(config_path, "w") as config_file: config_file.write(json.dumps(config, indent=4)) + +def save_metadata(layer, preset): from keras_nlp import __version__ as keras_nlp_version keras_version = keras.version() if hasattr(keras, "version") else None + metadata = { + "keras_version": keras_version, + "keras_nlp_version": keras_nlp_version, + "parameter_count": layer.count_params(), + "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), + } + metadata_path = os.path.join(preset, "metadata.json") + with open(metadata_path, "w") as metadata_file: + metadata_file.write(json.dumps(metadata, indent=4)) - # Save any associated metadata. 
- if config_filename == CONFIG_FILE: - metadata = { - "keras_version": keras_version, - "keras_nlp_version": keras_nlp_version, - "parameter_count": layer.count_params(), - "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), - } - metadata_path = os.path.join(preset, "metadata.json") - with open(metadata_path, "w") as metadata_file: - metadata_file.write(json.dumps(metadata, indent=4)) + +def save_weights(layer, preset, weights_file): + if not hasattr(layer, "save_weights"): + raise ValueError(f"`save_weights` hasn't been defined for `{layer}`.") + weights_path = os.path.join(preset, weights_file) + layer.save_weights(weights_path) def _validate_tokenizer(preset, allow_incomplete=False): @@ -393,23 +383,21 @@ def upload_preset( ) -def load_serialized_object(preset, config_file, config_overrides={}): +def load_config(preset, config_file=CONFIG_FILE): config_path = get_file(preset, config_file) with open(config_path) as config_file: config = json.load(config_file) - config["config"] = {**config["config"], **config_overrides} - return keras.saving.deserialize_keras_object(config) + return config -def save_serialized_object( - layer, preset, config_file=CONFIG_FILE, config_to_skip=[] +def load_serialized_object( + preset, + config_file=CONFIG_FILE, + config_overrides={}, ): - config_path = os.path.join(preset, config_file) - config = keras.saving.serialize_keras_object(layer) - for c in config_to_skip: - recursive_pop(config, c) - with open(config_path, "w") as config_file: - config_file.write(json.dumps(config, indent=4)) + config = load_config(preset, config_file) + config["config"] = {**config["config"], **config_overrides} + return keras.saving.deserialize_keras_object(config) def check_config_class(