Skip to content

Commit

Permalink
Merge pull request #466 from tomaarsen/fix/predict_proba_multi_output
Browse files Browse the repository at this point in the history
Resolve crash with predict_proba & multi-output
  • Loading branch information
tomaarsen committed Jan 8, 2024
2 parents 107ca56 + 750d5ba commit 58a3600
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,11 @@ def predict_proba(
inputs = [inputs]
embeddings = self.encode(inputs, batch_size=batch_size, show_progress_bar=show_progress_bar)
probs = self.model_head.predict_proba(embeddings)
if isinstance(probs, list):
if self.has_differentiable_head:
probs = torch.stack(probs, axis=1)
else:
probs = np.stack(probs, axis=1)
outputs = self._output_type_conversion(probs, as_numpy=as_numpy)
return outputs[0] if is_singular else outputs

Expand Down
32 changes: 30 additions & 2 deletions tests/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import numpy as np
import pytest
import torch
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import ClassifierChain, MultiOutputClassifier

-from setfit import SetFitHead, SetFitModel
+from setfit import SetFitHead, SetFitModel, Trainer
from setfit.modeling import MODEL_HEAD_NAME


Expand Down Expand Up @@ -324,3 +324,31 @@ def test_singular_predict() -> None:
assert probs.argmax() == 1
model.labels = ["negative", "positive"]
assert model("That was cool!") == "positive"


# Only the non-differentiable head is parametrized here: a differentiable head
# may still cause unexpected performance.
@pytest.mark.parametrize("use_differentiable_head", [False])
def test_predict_proba_multi_output(use_differentiable_head: bool) -> None:
    """Check ``predict_proba`` output types and shapes for a multi-output model.

    A model loaded with ``multi_target_strategy="multi-output"`` is trained on
    two samples carrying two-element label lists, then queried with a single
    string and with a list of three strings, each time once with the default
    output type and once with ``as_numpy=True``.
    """
    sample = "That was cool!"
    setfit_model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-albert-small-v2",
        multi_target_strategy="multi-output",
        use_differentiable_head=use_differentiable_head,
    )
    tiny_dataset = Dataset.from_dict({"text": ["Hello", "World"], "label": [[1, 0], [0, 1]]})
    Trainer(model=setfit_model, train_dataset=tiny_dataset).train()

    # (inputs, expected output shape): the single string comes first, then the batch.
    cases = (
        (sample, (2, 2)),
        ([sample] * 3, (3, 2, 2)),
    )
    # (extra keyword arguments, expected output type): the default call passes
    # no ``as_numpy`` at all, so the method's own default is what gets tested.
    output_kinds = (
        ({}, torch.Tensor),
        ({"as_numpy": True}, np.ndarray),
    )
    for inputs, expected_shape in cases:
        for extra_kwargs, expected_type in output_kinds:
            outputs = setfit_model.predict_proba(inputs, **extra_kwargs)
            assert isinstance(outputs, expected_type)
            assert outputs.shape == expected_shape

0 comments on commit 58a3600

Please sign in to comment.