Import FalconCausalLM in inits and uncomment a test.
SamanehSaadat committed May 20, 2024
1 parent 000e7d2 commit 1237565
Showing 3 changed files with 10 additions and 8 deletions.
1 change: 1 addition & 0 deletions keras_nlp/api/models/__init__.py
@@ -100,6 +100,7 @@
 from keras_nlp.src.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
+from keras_nlp.src.models.falcon.falcon_causal_lm import FalconCausalLM
 from keras_nlp.src.models.falcon.falcon_causal_lm_preprocessor import (
     FalconCausalLMPreprocessor,
 )
1 change: 1 addition & 0 deletions keras_nlp/src/models/__init__.py
@@ -95,6 +95,7 @@
 from keras_nlp.src.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
+from keras_nlp.src.models.falcon.falcon_causal_lm import FalconCausalLM
 from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM
16 changes: 8 additions & 8 deletions keras_nlp/src/models/falcon/falcon_causal_lm_test.py
@@ -60,14 +60,14 @@ def setUp(self):
         )
         self.input_data = self.preprocessor(*self.train_data)[0]

-    # def test_causal_lm_basics(self):
-    #     vocabulary_size = self.tokenizer.vocabulary_size()
-    #     self.run_task_test(
-    #         cls=FalconCausalLM,
-    #         init_kwargs=self.init_kwargs,
-    #         train_data=self.train_data,
-    #         expected_output_shape=(2, 8, vocabulary_size),
-    #     )
+    def test_causal_lm_basics(self):
+        vocabulary_size = self.tokenizer.vocabulary_size()
+        self.run_task_test(
+            cls=FalconCausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 8, vocabulary_size),
+        )

     def test_generate(self):
         causal_lm = FalconCausalLM(**self.init_kwargs)
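With FalconCausalLM now exported from both `__init__.py` files and its basic task test re-enabled, the class should be reachable through the public keras_nlp.models namespace. A minimal usage sketch under that assumption; the preset name "falcon_refinedweb_1b_en" is illustrative and should be replaced with whichever Falcon preset is actually registered:

```python
import keras_nlp

# After this commit, the task class resolves via the public API
# instead of only via keras_nlp.src.models.falcon.falcon_causal_lm.
causal_lm = keras_nlp.models.FalconCausalLM.from_preset(
    "falcon_refinedweb_1b_en"  # assumed preset name; swap in a registered Falcon preset
)
print(causal_lm.generate("The quick brown fox", max_length=30))
```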
