diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py
index 10d8e93ce..d48cbd970 100644
--- a/keras_nlp/src/models/preprocessor.py
+++ b/keras_nlp/src/models/preprocessor.py
@@ -181,7 +181,7 @@ def from_preset(
 
         tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE)
         tokenizer.load_preset_assets(preset)
-        preprocessor = cls(tokenizer=tokenizer)
+        preprocessor = cls(tokenizer=tokenizer, **kwargs)
 
         return preprocessor
 
diff --git a/keras_nlp/src/models/preprocessor_test.py b/keras_nlp/src/models/preprocessor_test.py
index 71c78a3bd..9837bf443 100644
--- a/keras_nlp/src/models/preprocessor_test.py
+++ b/keras_nlp/src/models/preprocessor_test.py
@@ -52,6 +52,13 @@ def test_from_preset(self):
             BertMaskedLMPreprocessor,
         )
 
+    @pytest.mark.large
+    def test_from_preset_with_sequence_length(self):
+        preprocessor = BertPreprocessor.from_preset(
+            "bert_tiny_en_uncased", sequence_length=16
+        )
+        self.assertEqual(preprocessor.sequence_length, 16)
+
     @pytest.mark.large
    def test_from_preset_errors(self):
        with self.assertRaises(ValueError):
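
For context, a minimal usage sketch of what the `**kwargs` forwarding enables, mirroring the `test_from_preset_with_sequence_length` test added above (assumes the `bert_tiny_en_uncased` preset is available locally or downloadable):

```python
from keras_nlp.models import BertPreprocessor

# Constructor keyword arguments such as `sequence_length` can now be
# passed straight through `from_preset` instead of being silently dropped.
preprocessor = BertPreprocessor.from_preset(
    "bert_tiny_en_uncased", sequence_length=16
)
print(preprocessor.sequence_length)  # 16
```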