Skip to content

Commit

Permalink
fix serder
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffpicard committed Dec 14, 2024
1 parent 8ae1ab8 commit 4ce1fd8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
10 changes: 5 additions & 5 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,26 +1101,26 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:

if saved_config is None:
config = AutoConfig.from_pretrained(
model, output_hidden_states=True, **transformers_config_kwargs, **kwargs
model, output_hidden_states=True, **transformers_config_kwargs
)

if is_supported_t5_model(config):
from transformers import T5EncoderModel

transformer_model = T5EncoderModel.from_pretrained(
model, config=config, **transformers_model_kwargs, **kwargs
model, config=config, **transformers_model_kwargs
)
else:
transformer_model = AutoModel.from_pretrained(
model, config=config, **transformers_model_kwargs, **kwargs
model, config=config, **transformers_model_kwargs
)
else:
if is_supported_t5_model(saved_config):
from transformers import T5EncoderModel

transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs, **kwargs)
transformer_model = T5EncoderModel(saved_config, **transformers_model_kwargs)
else:
transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs, **kwargs)
transformer_model = AutoModel.from_config(saved_config, **transformers_model_kwargs)
try:
transformer_model = transformer_model.to(flair.device)
except ValueError as e:
Expand Down
7 changes: 7 additions & 0 deletions tests/embedding_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from typing import Any, Optional

import pytest
Expand Down Expand Up @@ -183,3 +184,9 @@ def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
def test_embeddings_load_in_eval_mode(self):
embeddings = self.create_embedding_with_args(self.default_args)
assert not embeddings.training

def test_serializable(self):
embeddings = self.create_embedding_with_args(self.default_args)
serialized = pickle.dumps(embeddings)
deserialized = pickle.loads(serialized)
assert deserialized is not None
9 changes: 9 additions & 0 deletions tests/embeddings/test_transformer_document_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from flair.nn import LabelVerbalizerDecoder
from flair.data import Dictionary, Sentence
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
Expand Down Expand Up @@ -41,6 +42,14 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator


def test_loading_complex_models(tasks_base_path):
embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased")
model = TextClassifier(label_type="ner", label_dictionary=Dictionary(), embeddings=embeddings, decoder=LabelVerbalizerDecoder(embeddings, Dictionary()))
model.save(tasks_base_path / "single.pt")
loaded_embeddings = Classifier.load(tasks_base_path / "single.pt")
assert loaded_embeddings is not None


@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
def test_cls_pooling(cls_pooling):
embeddings = TransformerDocumentEmbeddings(
Expand Down

0 comments on commit 4ce1fd8

Please sign in to comment.