From a5da750aceb70ed02000dc0d196e2ceb9d65ba7d Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Fri, 26 Apr 2024 16:49:25 -0700 Subject: [PATCH] Allow a task preprocessor to be an argument in from_preset. (#1603) --- keras_nlp/models/task.py | 5 ++++- keras_nlp/models/task_test.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index b993240e4e..bb4e389e32 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -282,7 +282,10 @@ def from_preset( load_weights=load_weights, config_overrides=config_overrides, ) - preprocessor = cls.preprocessor_cls.from_preset(preset) + if "preprocessor" in kwargs: + preprocessor = kwargs.pop("preprocessor") + else: + preprocessor = cls.preprocessor_cls.from_preset(preset) return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) def load_task_weights(self, filepath): diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py index bc016f1731..6c6aa2543b 100644 --- a/keras_nlp/models/task_test.py +++ b/keras_nlp/models/task_test.py @@ -138,3 +138,13 @@ def test_save_to_preset(self): ref_out = model.predict(data) new_out = restored_model.predict(data) self.assertAllEqual(ref_out, new_out) + + @pytest.mark.keras_3_only + @pytest.mark.large + def test_none_preprocessor(self): + model = Classifier.from_preset( + "bert_tiny_en_uncased", + preprocessor=None, + num_classes=2, + ) + self.assertEqual(model.preprocessor, None)