Skip to content

Commit

Permalink
Updated classes logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Feb 17, 2025
1 parent 69d990a commit 065e04d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
9 changes: 4 additions & 5 deletions model2vec/inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,7 @@ def evaluate(
         :return: A classification report.
         """
         predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
-        report = evaluate_single_or_multi_label(
-            predictions=predictions, y=y, classes=self.classes_, output_dict=output_dict
-        )
+        report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

         return report

Expand Down Expand Up @@ -279,22 +277,23 @@ def _is_multi_label_shaped(y: LabelType) -> bool:
 def evaluate_single_or_multi_label(
     predictions: np.ndarray,
     y: LabelType,
-    classes: np.ndarray,
     output_dict: bool = False,
 ) -> str | dict[str, dict[str, float]]:
     """
     Evaluate the classifier on a given dataset using scikit-learn's classification report.

     :param predictions: The predictions.
     :param y: The ground truth labels.
-    :param classes: The classes of the classifier.
     :param output_dict: Whether to output the classification report as a dictionary.
     :return: A classification report.
     """
     if _is_multi_label_shaped(y):
+        classes = sorted(set([label for labels in y for label in labels]))
         mlb = MultiLabelBinarizer(classes=classes)
         y = mlb.fit_transform(y)
         predictions = mlb.transform(predictions)
+    elif isinstance(y[0], (str, int)):
+        classes = sorted(set(y))

     report = classification_report(
         y,
Expand Down
4 changes: 1 addition & 3 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ def evaluate(
         """
         self.eval()
         predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
-        report = evaluate_single_or_multi_label(
-            predictions=predictions, y=y, classes=self.classes, output_dict=output_dict
-        )
+        report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)

         return report

Expand Down

0 comments on commit 065e04d

Please sign in to comment.