Skip to content

Commit

Permalink
Updated evaluate, updated tests to also include int type labels
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 16, 2025
1 parent b6c00b8 commit a51f0bb
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
25 changes: 14 additions & 11 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,16 @@ def fit(
return self

def evaluate(
self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5
) -> dict[str, dict[str, float]]:
self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
) -> str | dict[str, dict[str, float]]:
"""
Evaluate the classifier on a given dataset using scikit-learn's classification report.
:param X: The texts to predict on.
:param y: The ground truth labels.
:param batch_size: The batch size.
:param threshold: The threshold for multilabel classification.
:param output_dict: Whether to output the classification report as a dictionary.
:return: A classification report.
"""
self.eval()
Expand All @@ -246,25 +247,27 @@ def evaluate(
if not self.multilabel:
# Encode the labels using a LabelEncoder
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y)
label_idx = label_encoder.fit_transform(self.classes_)
y = label_encoder.transform(y)
predictions = label_encoder.transform(predictions)
report = classification_report(
y,
predictions,
target_names=self.classes_,
output_dict=True,
labels=label_idx,
target_names=[str(c) for c in self.classes_],
output_dict=output_dict,
zero_division=0,
)
else:
# Encode the labels using a MultiLabelBinarizer
mlb = MultiLabelBinarizer(classes=self.classes)
y_true = mlb.fit_transform(y)
y_pred = mlb.transform(predictions)
y = mlb.fit_transform(y)
predictions = mlb.transform(predictions)
report = classification_report(
y_true,
y_pred,
target_names=mlb.classes_,
output_dict=True,
y,
predictions,
target_names=[str(c) for c in mlb.classes_],
output_dict=output_dict,
zero_division=0,
)

Expand Down
24 changes: 16 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,30 @@ def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification)
return mock_trained_pipeline.to_pipeline()


@pytest.fixture(params=[False, True], ids=["single_label", "multilabel"], scope="session")
@pytest.fixture(
params=[
(False, "single_label", "str"),
(False, "single_label", "int"),
(True, "multilabel", "str"),
(True, "multilabel", "int"),
],
ids=lambda param: f"{param[1]}_{param[2]}",
scope="session",
)
def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClassification:
"""Mock staticmodelforclassification."""
"""Mock StaticModelForClassification with different label formats."""
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
torch.random.manual_seed(42)
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")

X = ["dog", "cat"]
y: list[str] | list[list[str]]
if request.param:
# Use multilabel targets.
y = [["a", "b"], ["a"]]
is_multilabel, label_type = request.param[0], request.param[2]

if label_type == "str":
y = [["a", "b"], ["a"]] if is_multilabel else ["a", "b"] # type: ignore
else:
# Use singlelabel targets.
y = ["a", "b"]
y = [[0, 1], [0]] if is_multilabel else [0, 1] # type: ignore

model.fit(X, y)

Expand Down
26 changes: 24 additions & 2 deletions tests/test_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,15 @@ def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None:
"""Test the predict function."""
result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist()
if mock_trained_pipeline.multilabel:
assert result == [["a", "b"], ["a", "b"]]
if type(mock_trained_pipeline.classes_[0]) == str:
assert result == [["a", "b"], ["a", "b"]]
else:
assert result == [[0, 1], [0, 1]]
else:
assert result == ["b", "b"]
if type(mock_trained_pipeline.classes_[0]) == str:
assert result == ["b", "b"]
else:
assert result == [1, 1]


def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None:
Expand Down Expand Up @@ -146,3 +152,19 @@ def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -
assert len(b) == 2
assert len(c) == len(a)
assert len(d) == len(b)


def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
    """
    Smoke-test the evaluate function.

    Calls ``evaluate`` with ground-truth labels whose shape (single-label vs.
    multilabel) and element type (str vs. int) match what the fixture model
    reports through ``multilabel`` and ``classes_``. The test passes as long as
    the call completes without raising; the returned report is not inspected.

    :param mock_trained_pipeline: The trained classifier fixture under test.
    """
    texts = ["dog cat", "dog"]
    # Use isinstance instead of an exact type comparison so str subclasses
    # are also recognised as string labels.
    has_str_labels = isinstance(mock_trained_pipeline.classes_[0], str)

    if mock_trained_pipeline.multilabel:
        if has_str_labels:
            mock_trained_pipeline.evaluate(texts, [["a", "b"], ["a"]])
        else:
            # Ignore the type error since we don't support int labels in our typing, but the code does
            mock_trained_pipeline.evaluate(texts, [[0, 1], [0]])  # type: ignore
    else:
        if has_str_labels:
            mock_trained_pipeline.evaluate(texts, ["a", "a"])
        else:
            # Ignore the type error since we don't support int labels in our typing, but the code does
            mock_trained_pipeline.evaluate(texts, [1, 1])  # type: ignore

0 comments on commit a51f0bb

Please sign in to comment.