diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index eb3ff1bb49..4e5b361e45 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -300,16 +300,15 @@ def from_preset( if load_weights: task.load_weights(get_file(preset, TASK_WEIGHTS_FILE)) task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE)) - # TODO: is task.preprocessor None before this assignment? + # `task.preprocessor` is None before this assignment. task.preprocessor = preprocessor return task - def load_weights(self, filepath, skip_mismatch=False): + def load_weights(self, filepath): """Load only the tasks specific weights not in the backbone.""" if not str(filepath).endswith(".weights.h5"): raise ValueError( - "The filename must end in `.weights.h5`. " - f"Received: filepath={filepath}" + "The filename must end in `.weights.h5`. Received: filepath={filepath}" ) backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers()) keras.saving.save_weights(