Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SetFitModel: not a dataclass, not a PyTorchModelHubMixin #505

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
@@ -2,11 +2,9 @@
import os
import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple, Union


# For Python 3.7 compatibility
try:
from typing import Literal
@@ -17,12 +15,11 @@
import numpy as np
import requests
import torch
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub import ModelHubMixin, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from packaging.version import Version, parse
from sentence_transformers import SentenceTransformer
from sentence_transformers import SentenceTransformer, models
from sentence_transformers import __version__ as sentence_transformers_version
from sentence_transformers import models
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
@@ -36,7 +33,6 @@
from .model_card import SetFitModelCardData, generate_model_card
from .utils import set_docstring


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

@@ -196,8 +192,7 @@ def __repr__(self) -> str:
return "SetFitHead({})".format(self.get_config_dict())


@dataclass
class SetFitModel(PyTorchModelHubMixin):
class SetFitModel(ModelHubMixin):
"""A SetFit model with integration to the [Hugging Face Hub](https://huggingface.co).

Example::
@@ -212,19 +207,27 @@ class SetFitModel(PyTorchModelHubMixin):
['positive', 'negative', 'negative']
"""

model_body: Optional[SentenceTransformer] = None
model_head: Optional[Union[SetFitHead, LogisticRegression]] = None
multi_target_strategy: Optional[str] = None
normalize_embeddings: bool = False
labels: Optional[List[str]] = None
model_card_data: Optional[SetFitModelCardData] = field(default_factory=SetFitModelCardData)
sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)

attributes_to_save: Set[str] = field(
init=False, repr=False, default_factory=lambda: {"normalize_embeddings", "labels"}
)

def __post_init__(self):
def __init__(
self,
model_body: Optional[SentenceTransformer] = None,
model_head: Optional[Union[SetFitHead, LogisticRegression]] = None,
multi_target_strategy: Optional[str] = None,
normalize_embeddings: bool = False,
labels: Optional[List[str]] = None,
model_card_data: Optional[SetFitModelCardData] = None,
sentence_transformers_kwargs: Optional[Dict] = None,
**kwargs,
) -> None:
super(SetFitModel, self).__init__()
self.model_body = model_body
self.model_head = model_head
self.multi_target_strategy = multi_target_strategy
self.normalize_embeddings = normalize_embeddings
self.labels = labels
self.model_card_data = model_card_data or SetFitModelCardData()
self.sentence_transformers_kwargs = sentence_transformers_kwargs or {}

self.attributes_to_save: Set[str] = {"normalize_embeddings", "labels"}
self.model_card_data.register_model(self)

@property
Expand Down
31 changes: 16 additions & 15 deletions src/setfit/span/modeling.py
Original file line number Diff line number Diff line change
@@ -4,9 +4,9 @@
import tempfile
import types
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from datasets import Dataset
@@ -18,25 +18,25 @@
from ..modeling import SetFitModel
from .aspect_extractor import AspectExtractor


if TYPE_CHECKING:
from spacy.tokens import Doc

logger = logging.get_logger(__name__)


@dataclass
class SpanSetFitModel(SetFitModel):
spacy_model: str = "en_core_web_lg"
span_context: int = 0

attributes_to_save: Set[str] = field(
init=False,
repr=False,
default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"},
)
def __init__(
self,
spacy_model: str = "en_core_web_lg",
span_context: int = 0,
**kwargs,
):
super().__init__(**kwargs)
self.spacy_model = spacy_model
self.span_context = span_context
self.attributes_to_save = {"normalize_embeddings", "labels", "span_context", "spacy_model"}

def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]:
def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> Iterable[str]:
for doc, aspects in zip(docs, aspects_list):
for aspect_slice in aspects:
aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context]
@@ -137,9 +137,10 @@ def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[b
AspectModel.from_pretrained = types.MethodType(AspectModel.from_pretrained.__func__, AspectModel)


@dataclass
class PolarityModel(SpanSetFitModel):
span_context: int = 3
def __init__(self, span_context: int = 3, **kwargs):
super().__init__(**kwargs)
self.span_context = span_context


PolarityModel.from_pretrained = types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel)
Expand Down
4 changes: 2 additions & 2 deletions tests/span/aspect_model_card_pattern.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
- sentence-transformers
- text-classification
- generated_from_setfit_trainer
base_model: sentence-transformers/paraphrase-albert-small-v2
metrics:
- accuracy
widget:
@@ -31,8 +32,7 @@
ram_total_size: [\d\.]+
hours_used: [\d\.]+
( hardware_used: .+
)?base_model: sentence-transformers/paraphrase-albert-small-v2
model-index:
)?model-index:
- name: SetFit Aspect Model with sentence-transformers\/paraphrase-albert-small-v2
results:
- task:
Expand Down
4 changes: 2 additions & 2 deletions tests/span/polarity_model_card_pattern.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
- sentence-transformers
- text-classification
- generated_from_setfit_trainer
base_model: sentence-transformers/paraphrase-albert-small-v2
metrics:
- accuracy
widget:
@@ -31,8 +32,7 @@
ram_total_size: [\d\.]+
hours_used: [\d\.]+
( hardware_used: .+
)?base_model: sentence-transformers/paraphrase-albert-small-v2
model-index:
)?model-index:
- name: SetFit Polarity Model with sentence-transformers\/paraphrase-albert-small-v2
results:
- task:
Expand Down
Loading