diff --git a/keras_nlp/api/models/__init__.py b/keras_nlp/api/models/__init__.py
index 766959c988..217363c93a 100644
--- a/keras_nlp/api/models/__init__.py
+++ b/keras_nlp/api/models/__init__.py
@@ -100,6 +100,7 @@
 from keras_nlp.src.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
+from keras_nlp.src.models.falcon.falcon_causal_lm import FalconCausalLM
 from keras_nlp.src.models.falcon.falcon_causal_lm_preprocessor import (
     FalconCausalLMPreprocessor,
 )
diff --git a/keras_nlp/src/models/__init__.py b/keras_nlp/src/models/__init__.py
index bd134f4f4d..ed8a74c6b9 100644
--- a/keras_nlp/src/models/__init__.py
+++ b/keras_nlp/src/models/__init__.py
@@ -95,6 +95,7 @@
 from keras_nlp.src.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
+from keras_nlp.src.models.falcon.falcon_causal_lm import FalconCausalLM
 from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM
diff --git a/keras_nlp/src/models/falcon/falcon_causal_lm_test.py b/keras_nlp/src/models/falcon/falcon_causal_lm_test.py
index 0f4c383953..70a5867d7b 100644
--- a/keras_nlp/src/models/falcon/falcon_causal_lm_test.py
+++ b/keras_nlp/src/models/falcon/falcon_causal_lm_test.py
@@ -60,14 +60,14 @@ def setUp(self):
         )
         self.input_data = self.preprocessor(*self.train_data)[0]
 
-    # def test_causal_lm_basics(self):
-    #     vocabulary_size = self.tokenizer.vocabulary_size()
-    #     self.run_task_test(
-    #         cls=FalconCausalLM,
-    #         init_kwargs=self.init_kwargs,
-    #         train_data=self.train_data,
-    #         expected_output_shape=(2, 8, vocabulary_size),
-    #     )
+    def test_causal_lm_basics(self):
+        vocabulary_size = self.tokenizer.vocabulary_size()
+        self.run_task_test(
+            cls=FalconCausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 8, vocabulary_size),
+        )
 
     def test_generate(self):
         causal_lm = FalconCausalLM(**self.init_kwargs)
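
For context, a minimal usage sketch of what the new export enables is below. It assumes the package is importable as keras_nlp and that a Falcon preset is available; the preset name is illustrative and not confirmed by this diff.

import keras_nlp

# With FalconCausalLM exported from keras_nlp.models, the task can be built
# directly from the public API. The preset name below is an assumption for
# illustration; substitute any available Falcon preset.
causal_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
causal_lm.generate("The quick brown fox", max_length=30)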