Skip to content

Commit

Permalink
Improve messages and docs is task.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 8, 2024
1 parent 616dabe commit f6e7b5a
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,6 @@ def from_preset(
)

# Load task from preset if it exists.
if not issubclass(cls, Task):
raise ValueError(
"`{cls.__name__}` should be subclass of Task!"
) # TODO: update error message
if not issubclass(task_preset_cls, cls):
raise ValueError(
f"`{TASK_CONFIG_FILE}` has type `{task_preset_cls.__name__}` "
Expand All @@ -348,11 +344,10 @@ def load_weights(self, filepath, skip_mismatch=False):
"The filename must end in `.weights.h5`. "
f"Received: filepath={filepath}"
)
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
weights_store = keras.src.saving.saving_lib.H5IOStore(
filepath, mode="r"
)
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
# TODO: It's better not to use this private API here. Francois recommends chaning our public saving API and skip objects to it. Francoins will do this.
keras.src.saving.saving_lib._load_state(
self,
weights_store=weights_store,
Expand All @@ -363,6 +358,12 @@ def load_weights(self, filepath, skip_mismatch=False):
failed_trackables=set(),
)
weights_store.close()
# TODO: use the following for weight loading when a new version of keras is released.
# keras.saving.save_weights(
# self,
# filepath,
# objecst_to_skip=backbone_layer_ids,
# )

def save_weights(self, filepath):
"""Save only the tasks specific weights not in the backbone."""
Expand Down Expand Up @@ -390,13 +391,19 @@ def save_weights(self, filepath):
visited_trackables=backbone_layer_ids,
)
weights_store.close()
# TODO: use the following for weight loading.
# keras.saving.load_weights(filepath, objects_to_skip=backbone_layer_ids)

def save_to_preset(self, preset):
"""TODO: add docstring"""
"""Save task to a preset directory.
Args:
preset: The path to the local model preset directory.
"""
if self.preprocessor is None:
raise ValueError(
"Preprocessor is not defined!"
) # TODO: improve error message
"Cannot save `task` to preset: `Preprocessor` is not initialized."
)

self.preprocessor.save_to_preset(preset)
self.backbone.save_to_preset(preset)
Expand Down

0 comments on commit f6e7b5a

Please sign in to comment.