Skip to content

Commit

Permalink
Random fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 10, 2024
1 parent fa3c6fe commit c238621
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,15 @@ def from_preset(
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
# TODO: is task.preprocessor None before this assignment?
# `task.preprocessor` is None before this assignment.
task.preprocessor = preprocessor
return task

def load_weights(self, filepath, skip_mismatch=False):
def load_weights(self, filepath):
"""Load only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
"The filename must end in `.weights.h5`. "
f"Received: filepath={filepath}"
"The filename must end in `.weights.h5`. Received: filepath={filepath}"
)
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
keras.saving.save_weights(
Expand Down

0 comments on commit c238621

Please sign in to comment.