Skip to content

Commit

Permalink
Allow a task preprocessor to be an argument in from_preset. (keras-te…
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 26, 2024
1 parent 4c8d0bc commit a5da750
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
5 changes: 4 additions & 1 deletion keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,10 @@ def from_preset(
load_weights=load_weights,
config_overrides=config_overrides,
)
preprocessor = cls.preprocessor_cls.from_preset(preset)
if "preprocessor" in kwargs:
preprocessor = kwargs.pop("preprocessor")
else:
preprocessor = cls.preprocessor_cls.from_preset(preset)
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

def load_task_weights(self, filepath):
Expand Down
10 changes: 10 additions & 0 deletions keras_nlp/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,13 @@ def test_save_to_preset(self):
ref_out = model.predict(data)
new_out = restored_model.predict(data)
self.assertAllEqual(ref_out, new_out)

@pytest.mark.keras_3_only
@pytest.mark.large
def test_none_preprocessor(self):
model = Classifier.from_preset(
"bert_tiny_en_uncased",
preprocessor=None,
num_classes=2,
)
self.assertEqual(model.preprocessor, None)

0 comments on commit a5da750

Please sign in to comment.