diff --git a/src/setfit/modeling.py b/src/setfit/modeling.py index 2cef9521..02330e97 100644 --- a/src/setfit/modeling.py +++ b/src/setfit/modeling.py @@ -2,11 +2,9 @@ import os import tempfile import warnings -from dataclasses import dataclass, field from pathlib import Path from typing import Dict, List, Optional, Set, Tuple, Union - # For Python 3.7 compatibility try: from typing import Literal @@ -17,12 +15,11 @@ import numpy as np import requests import torch -from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from huggingface_hub import ModelHubMixin, hf_hub_download from huggingface_hub.utils import validate_hf_hub_args from packaging.version import Version, parse -from sentence_transformers import SentenceTransformer +from sentence_transformers import SentenceTransformer, models from sentence_transformers import __version__ as sentence_transformers_version -from sentence_transformers import models from sklearn.linear_model import LogisticRegression from sklearn.multiclass import OneVsRestClassifier from sklearn.multioutput import ClassifierChain, MultiOutputClassifier @@ -36,7 +33,6 @@ from .model_card import SetFitModelCardData, generate_model_card from .utils import set_docstring - logging.set_verbosity_info() logger = logging.get_logger(__name__) @@ -196,8 +192,7 @@ def __repr__(self) -> str: return "SetFitHead({})".format(self.get_config_dict()) -@dataclass -class SetFitModel(PyTorchModelHubMixin): +class SetFitModel(ModelHubMixin): """A SetFit model with integration to the [Hugging Face Hub](https://huggingface.co). Example:: @@ -212,19 +207,27 @@ class SetFitModel(PyTorchModelHubMixin): ['positive', 'negative', 'negative'] """ - model_body: Optional[SentenceTransformer] = None - model_head: Optional[Union[SetFitHead, LogisticRegression]] = None - multi_target_strategy: Optional[str] = None - normalize_embeddings: bool = False - labels: Optional[List[str]] = None - model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData) - sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False) - - attributes_to_save: Set[str] = field( - init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"} - ) - - def __post_init__(self): + def __init__( + self, + model_body: Optional[SentenceTransformer] = None, + model_head: Optional[Union[SetFitHead, LogisticRegression]] = None, + multi_target_strategy: Optional[str] = None, + normalize_embeddings: bool = False, + labels: Optional[List[str]] = None, + model_card_data: Optional[SetFitModelCardData] = None, + sentence_transformers_kwargs: Optional[Dict] = None, + **kwargs, + ) -> None: + super(SetFitModel, self).__init__() + self.model_body = model_body + self.model_head = model_head + self.multi_target_strategy = multi_target_strategy + self.normalize_embeddings = normalize_embeddings + self.labels = labels + self.model_card_data = model_card_data or SetFitModelCardData() + self.sentence_transformers_kwargs = sentence_transformers_kwargs or {} + + self.attributes_to_save: Set[str] = {"normalize_embeddings", "labels"} self.model_card_data.register_model(self) @property diff --git a/src/setfit/span/modeling.py b/src/setfit/span/modeling.py index a6e45746..cd3f312d 100644 --- a/src/setfit/span/modeling.py +++ b/src/setfit/span/modeling.py @@ -4,9 +4,9 @@ import tempfile import types from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import torch from datasets import Dataset @@ -18,25 +18,25 @@ from ..modeling import SetFitModel from .aspect_extractor import AspectExtractor - if TYPE_CHECKING: from spacy.tokens import Doc logger = logging.get_logger(__name__) -@dataclass class SpanSetFitModel(SetFitModel): - spacy_model: str = "en_core_web_lg" - span_context: int = 0 - - attributes_to_save: Set[str] = field( - init=False, - repr=False, - default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"}, - ) + def __init__( + self, + spacy_model: str = "en_core_web_lg", + span_context: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.spacy_model = spacy_model + self.span_context = span_context + self.attributes_to_save = {"normalize_embeddings", "labels", "span_context", "spacy_model"} - def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]: + def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> Iterable[str]: for doc, aspects in zip(docs, aspects_list): for aspect_slice in aspects: aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context] @@ -137,9 +137,10 @@ def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[b AspectModel.from_pretrained = types.MethodType(AspectModel.from_pretrained.__func__, AspectModel) -@dataclass class PolarityModel(SpanSetFitModel): - span_context: int = 3 + def __init__(self, span_context: int = 3, **kwargs): + super().__init__(**kwargs) + self.span_context = span_context PolarityModel.from_pretrained = types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel) diff --git a/tests/span/aspect_model_card_pattern.py b/tests/span/aspect_model_card_pattern.py index 8295a25a..f542ffca 100644 --- a/tests/span/aspect_model_card_pattern.py +++ b/tests/span/aspect_model_card_pattern.py @@ -16,6 +16,7 @@ - sentence-transformers - text-classification - generated_from_setfit_trainer +base_model: sentence-transformers/paraphrase-albert-small-v2 metrics: - accuracy widget: @@ -31,8 +32,7 @@ ram_total_size: [\d\.]+ hours_used: [\d\.]+ ( hardware_used: .+ -)?base_model: sentence-transformers/paraphrase-albert-small-v2 -model-index: +)?model-index: - name: SetFit Aspect Model with sentence-transformers\/paraphrase-albert-small-v2 results: - task: diff --git a/tests/span/polarity_model_card_pattern.py b/tests/span/polarity_model_card_pattern.py index 921ad2bb..64046801 100644 --- a/tests/span/polarity_model_card_pattern.py +++ b/tests/span/polarity_model_card_pattern.py @@ -16,6 +16,7 @@ - sentence-transformers - text-classification - generated_from_setfit_trainer +base_model: sentence-transformers/paraphrase-albert-small-v2 metrics: - accuracy widget: @@ -31,8 +32,7 @@ ram_total_size: [\d\.]+ hours_used: [\d\.]+ ( hardware_used: .+ -)?base_model: sentence-transformers/paraphrase-albert-small-v2 -model-index: +)?model-index: - name: SetFit Polarity Model with sentence-transformers\/paraphrase-albert-small-v2 results: - task: