diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index fdb16eea2..d1a7de939 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -1101,26 +1101,26 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if saved_config is None:
             config = AutoConfig.from_pretrained(
-                model, output_hidden_states=True, **transformers_config_kwargs, **kwargs
+                model, output_hidden_states=True, **transformers_config_kwargs
             )
             if is_supported_t5_model(config):
                 from transformers import T5EncoderModel

                 transformer_model = T5EncoderModel.from_pretrained(
-                    model, config=config, **transformers_model_kwargs, **kwargs
+                    model, config=config, **transformers_model_kwargs
                 )
             else:
                 transformer_model = AutoModel.from_pretrained(
-                    model, config=config, **transformers_model_kwargs, **kwargs
+                    model, config=config, **transformers_model_kwargs
                 )
         else:
             if is_supported_t5_model(saved_config):
                 from transformers import T5EncoderModel

-                transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs, **kwargs)
+                transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs)
             else:
-                transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs, **kwargs)
+                transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs)

         try:
             transformer_model = transformer_model.to(flair.device)
         except ValueError as e:
diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py
index c1a0b1a79..be44d1600 100644
--- a/tests/embedding_test_utils.py
+++ b/tests/embedding_test_utils.py
@@ -1,3 +1,4 @@
+import pickle
 from typing import Any, Optional

 import pytest
@@ -183,3 +184,9 @@ def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
     def test_embeddings_load_in_eval_mode(self):
         embeddings = self.create_embedding_with_args(self.default_args)
         assert not embeddings.training
+
+    def test_serializable(self):
+        embeddings = self.create_embedding_with_args(self.default_args)
+        serialized = pickle.dumps(embeddings)
+        deserialized = pickle.loads(serialized)
+        assert deserialized is not None
diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index f0f6389b7..339a7d553 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -1,5 +1,6 @@
 import pytest

+from flair.nn import LabelVerbalizerDecoder
 from flair.data import Dictionary, Sentence
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
@@ -41,6 +42,14 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator


+def test_loading_complex_models(tasks_base_path):
+    embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased")
+    model = TextClassifier(label_type="ner", label_dictionary=Dictionary(), embeddings=embeddings, decoder=LabelVerbalizerDecoder(embeddings, Dictionary()))
+    model.save(tasks_base_path / "single.pt")
+    loaded_embeddings = Classifier.load(tasks_base_path / "single.pt")
+    assert loaded_embeddings is not None
+
+
 @pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
 def test_cls_pooling(cls_pooling):
     embeddings = TransformerDocumentEmbeddings(