From 53081a9cece79ca34938df0316b150b540ededc1 Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Tue, 2 Apr 2024 23:40:32 +0000 Subject: [PATCH] Fix task loading bugs in from_preset and save_to_preset. --- keras_nlp/models/task.py | 71 ++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 748b3727b7..7621b7f793 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -275,29 +275,6 @@ def from_preset( f"Received: backbone={kwargs['backbone']}." ) - # Q: Should I move task loading logic to preset_utils.py? - # Load Task from preset. - # TODO: Do we expect remote paths here? os.path doesn't work for remote paths. replacements: https://github.com/keras-team/keras/blob/master/keras/utils/file_utils.py - task_config_path = os.path.join(preset, TASK_CONFIG_FILE) - if issubclass(cls, Task) and os.path.exists(task_config_path): - task_config_class = check_config_class( - preset, config_file=TASK_CONFIG_FILE - ) - if not issubclass(task_config_class, cls): - raise ValueError( - f"`{TASK_CONFIG_FILE}` has type `{task_config_class.__name__}` " - f"which is not a subclass of calling class `{cls.__name__}`. Call " - f"`from_preset` directly on `{task_config_class.__name__}` instead." - ) - - task_config_file = os.path.join(preset, TASK_CONFIG_FILE) - with open(task_config_file, "r") as config_file: - task_config = json.load(config_file) - task = keras.saving.deserialize_keras_object(task_config) - load_weights = load_weights and task_config["weights"] - task_weights_path = os.path.join(preset, task_config["weights"]) - task.load_task_weights(task_weights_path) - # Load backbone from preset. 
config_path = os.path.join(preset, CONFIG_FILE) if not os.path.exists(config_path): @@ -343,11 +320,39 @@ def from_preset( config_file=TOKENIZER_CONFIG_FILE, ) preprocessor = cls.preprocessor_cls(tokenizer=tokenizer) - if task: - task.backbone = backbone - task.preprocessor = preprocessor - return task - return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) + + # TODO: Do we expect remote paths here? os.path doesn't work for remote paths. Possible replacements: https://github.com/keras-team/keras/blob/master/keras/utils/file_utils.py + task_config_path = os.path.join(preset, TASK_CONFIG_FILE) + if not os.path.exists(task_config_path): + return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) + + # TODO: Consider moving the task loading logic to preset_utils.py. + # Load Task from preset. + if not issubclass(cls, Task): + raise ValueError( + f"`{cls.__name__}` should be subclass of Task!" + ) # TODO: update error message + task_config_class = check_config_class( + preset, config_file=TASK_CONFIG_FILE + ) + if not issubclass(task_config_class, cls): + raise ValueError( + f"`{TASK_CONFIG_FILE}` has type `{task_config_class.__name__}` " + f"which is not a subclass of calling class `{cls.__name__}`. Call " + f"`from_preset` directly on `{task_config_class.__name__}` instead." + ) + + task_config_file = os.path.join(preset, TASK_CONFIG_FILE) + with open(task_config_file, "r") as config_file: + task_config = json.load(config_file) + # TODO: add back backbone and preprocessor config when save_to_preset removes them (right now, save_to_preset doesn't remove them). 
+ # task_config.update(backbone_config) + # task_config.update(preprocessor_config) + task = keras.saving.deserialize_keras_object(task_config) + load_weights = load_weights and task_config["weights"] + task_weights_path = os.path.join(preset, task_config["weights"]) + task.load_task_weights(task_weights_path) + return task def load_task_weights(self, filepath, skip_mismatch=False): """Load only the tasks specific weights not in the backbone.""" @@ -400,6 +405,7 @@ def save_task_weights(self, filepath): ) weights_store.close() + # TODO: do we want to have a `save_weights` flag in this public save_to_preset? probably yes! def save_to_preset(self, preset): """TODO: add docstring""" if self.preprocessor is None: @@ -409,17 +415,20 @@ def save_to_preset(self, preset): self.preprocessor.save_to_preset(preset) self.backbone.save_to_preset(preset) + weights_filename = "task.weights.h5" # TODO: the serialization and saving logic should probably be moved to preset_utils.py task_config_path = os.path.join(preset, TASK_CONFIG_FILE) task_config = keras.saving.serialize_keras_object(self) recursive_pop(task_config, "compile_config") recursive_pop(task_config, "build_config") - recursive_pop(task_config, "preprocessor") - recursive_pop(task_config, "backbone") + # TODO: remove preprocessor and backbone from task.json to prevent redundancy in config files. + # recursive_pop(task_config, "preprocessor") + # recursive_pop(task_config, "backbone") + task_config["weights"] = weights_filename with open(task_config_path, "w") as config_file: config_file.write(json.dumps(task_config, indent=4)) - task_weights_path = os.path.join(preset, "task.weights.h5") + task_weights_path = os.path.join(preset, weights_filename) self.save_task_weights(task_weights_path) @property