Skip to content

Commit

Permalink
Fix bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 2, 2024
1 parent 1df3e08 commit 53081a9
Showing 1 changed file with 40 additions and 31 deletions.
71 changes: 40 additions & 31 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,29 +275,6 @@ def from_preset(
f"Received: backbone={kwargs['backbone']}."
)

# Q: Should I move task loading logic to preset_utils.py?
# Load Task from preset.
# TODO: Do we expect remote paths here? os.path doesn't work for remote paths. replacements: https://github.com/keras-team/keras/blob/master/keras/utils/file_utils.py
task_config_path = os.path.join(preset, TASK_CONFIG_FILE)
if issubclass(cls, Task) and os.path.exists(task_config_path):
task_config_class = check_config_class(
preset, config_file=TASK_CONFIG_FILE
)
if not issubclass(task_config_class, cls):
raise ValueError(
f"`{TASK_CONFIG_FILE}` has type `{task_config_class.__name__}` "
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_config_class.__name__}` instead."
)

task_config_file = os.path.join(preset, TASK_CONFIG_FILE)
with open(task_config_file, "r") as config_file:
task_config = json.load(config_file)
task = keras.saving.deserialize_keras_object(task_config)
load_weights = load_weights and task_config["weights"]
task_weights_path = os.path.join(preset, task_config["weights"])
task.load_task_weights(task_weights_path)

# Load backbone from preset.
config_path = os.path.join(preset, CONFIG_FILE)
if not os.path.exists(config_path):
Expand Down Expand Up @@ -343,11 +320,39 @@ def from_preset(
config_file=TOKENIZER_CONFIG_FILE,
)
preprocessor = cls.preprocessor_cls(tokenizer=tokenizer)
if task:
task.backbone = backbone
task.preprocessor = preprocessor
return task
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

# TODO: Do we expect remote paths here? os.path doesn't work for remote paths. replacements: https://github.com/keras-team/keras/blob/master/keras/utils/file_utils.py
task_config_path = os.path.join(preset, TASK_CONFIG_FILE)
if not os.path.exists(task_config_path):
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

# TODO: I should probably move task loading logic to preset_utils.py?
# Load Task from preset.
if not issubclass(cls, Task):
raise ValueError(
"`{cls.__name__}` should be subclass of Task!"
) # TODO: update error message
task_config_class = check_config_class(
preset, config_file=TASK_CONFIG_FILE
)
if not issubclass(task_config_class, cls):
raise ValueError(
f"`{TASK_CONFIG_FILE}` has type `{task_config_class.__name__}` "
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_config_class.__name__}` instead."
)

task_config_file = os.path.join(preset, TASK_CONFIG_FILE)
with open(task_config_file, "r") as config_file:
task_config = json.load(config_file)
# TODO: add back backbone and preprocessor config when save_to_preset removes them (rn, save_to_preset, doesn't remove them!).
# task_config.update(backbone_config)
# task_config.update(preprocessor_config)
task = keras.saving.deserialize_keras_object(task_config)
load_weights = load_weights and task_config["weights"]
task_weights_path = os.path.join(preset, task_config["weights"])
task.load_task_weights(task_weights_path)
return task

def load_task_weights(self, filepath, skip_mismatch=False):
"""Load only the tasks specific weights not in the backbone."""
Expand Down Expand Up @@ -400,6 +405,7 @@ def save_task_weights(self, filepath):
)
weights_store.close()

# TODO: do we want to have a `save_weights` flag in this public save_to_preset? probably yes!
def save_to_preset(self, preset):
"""TODO: add docstring"""
if self.preprocessor is None:
Expand All @@ -409,17 +415,20 @@ def save_to_preset(self, preset):

self.preprocessor.save_to_preset(preset)
self.backbone.save_to_preset(preset)
weights_filename = "task.weights.h5"

# TODO: the serialization and saving logic should probably be moved to preset_utils.py
task_config_path = os.path.join(preset, TASK_CONFIG_FILE)
task_config = keras.saving.serialize_keras_object(self)
recursive_pop(task_config, "compile_config")
recursive_pop(task_config, "build_config")
recursive_pop(task_config, "preprocessor")
recursive_pop(task_config, "backbone")
# TODO: remove preprocessor and backbone from task.json to prevent redundancy in config files.
# recursive_pop(task_config, "preprocessor")
# recursive_pop(task_config, "backbone")
task_config["weights"] = weights_filename
with open(task_config_path, "w") as config_file:
config_file.write(json.dumps(task_config, indent=4))
task_weights_path = os.path.join(preset, "task.weights.h5")
task_weights_path = os.path.join(preset, weights_filename)
self.save_task_weights(task_weights_path)

@property
Expand Down

0 comments on commit 53081a9

Please sign in to comment.