Skip to content

Commit

Permalink
add multilabel targets, fix tests (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul authored Feb 15, 2025
1 parent 43de6da commit 8e944ab
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 14 deletions.
56 changes: 50 additions & 6 deletions model2vec/inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import huggingface_hub
import numpy as np
import skops.io
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline

from model2vec.hf_utils import _create_model_card
Expand All @@ -21,6 +22,20 @@ def __init__(self, model: StaticModel, head: Pipeline) -> None:
"""Create a pipeline with a StaticModel encoder."""
self.model = model
self.head = head
classifier = self.head[-1]
# Check if the classifier is a multilabel classifier.
# NOTE: this doesn't look robust, but it is: only an MLPClassifier with a logistic output
# activation needs its probabilities thresholded. Other classifiers, such as OVR wrappers,
# support multilabel output natively, so we can just use predict.
self.multilabel = False
if isinstance(classifier, MLPClassifier):
if classifier.out_activation_ == "logistic":
self.multilabel = True

@property
def classes_(self) -> np.ndarray:
    """Return the class labels exposed by the classification head."""
    head = self.head
    return head.classes_

@classmethod
def from_pretrained(
Expand Down Expand Up @@ -60,7 +75,7 @@ def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = Fa
self.model.save_pretrained(temp_dir)
push_folder_to_hub(Path(temp_dir), repo_id, private, token)

def _predict_and_coerce_to_2d(
def _encode_and_coerce_to_2d(
self,
X: list[str] | str,
show_progress_bar: bool,
Expand All @@ -69,7 +84,7 @@ def _predict_and_coerce_to_2d(
use_multiprocessing: bool,
multiprocessing_threshold: int,
) -> np.ndarray:
"""Predict the labels of the input and coerce the output to a matrix."""
"""Encode the instances and coerce the output to a matrix."""
encoded = self.model.encode(
X,
show_progress_bar=show_progress_bar,
Expand All @@ -91,9 +106,21 @@ def predict(
batch_size: int = 1024,
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
threshold: float = 0.5,
) -> np.ndarray:
"""Predict the labels of the input."""
encoded = self._predict_and_coerce_to_2d(
"""
Predict the labels of the input.
:param X: The input data to predict. Can be a list of strings or a single string.
:param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
:param max_length: The maximum length of the input sequences. Defaults to 512.
:param batch_size: The batch size for prediction. Defaults to 1024.
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
:param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel.
:return: The predicted labels.
"""
encoded = self._encode_and_coerce_to_2d(
X,
show_progress_bar=show_progress_bar,
max_length=max_length,
Expand All @@ -102,6 +129,13 @@ def predict(
multiprocessing_threshold=multiprocessing_threshold,
)

if self.multilabel:
out_labels = []
proba = self.head.predict_proba(encoded)
for vector in proba:
out_labels.append(self.classes_[vector > threshold])
return np.asarray(out_labels)

return self.head.predict(encoded)

def predict_proba(
Expand All @@ -113,8 +147,18 @@ def predict_proba(
use_multiprocessing: bool = True,
multiprocessing_threshold: int = 10_000,
) -> np.ndarray:
"""Predict the probabilities of the labels of the input."""
encoded = self._predict_and_coerce_to_2d(
"""
Predict the probabilities of the labels of the input.
:param X: The input data to predict. Can be a list of strings or a single string.
:param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
:param max_length: The maximum length of the input sequences. Defaults to 512.
:param batch_size: The batch size for prediction. Defaults to 1024.
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
:return: The predicted probabilities.
"""
encoded = self._encode_and_coerce_to_2d(
X,
show_progress_bar=show_progress_bar,
max_length=max_length,
Expand Down
5 changes: 2 additions & 3 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ def to_pipeline(self) -> StaticModelPipeline:
# To convert correctly, we need to set the outputs correctly, and fix the activation function.
# Make sure n_outputs is set to > 1.
mlp_head.n_outputs_ = self.out_dim
# Set to softmax
mlp_head.out_activation_ = "softmax"
# Set to softmax or sigmoid
mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax"

return StaticModelPipeline(static_model, converted)

Expand Down Expand Up @@ -373,7 +373,6 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
mode="min",
factor=0.5,
patience=3,
verbose=True,
min_lr=1e-6,
threshold=0.03,
threshold_mode="rel",
Expand Down
18 changes: 14 additions & 4 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@

def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None:
"""Test successful initialization of StaticModelPipeline."""
assert mock_inference_pipeline.predict("dog").tolist() == ["b"]
assert mock_inference_pipeline.predict(["dog"]).tolist() == ["b"]
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
target = [["a", "b"]]
else:
target = ["b"]
assert mock_inference_pipeline.predict("dog").tolist() == target
assert mock_inference_pipeline.predict(["dog"]).tolist() == target


def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None:
Expand All @@ -25,8 +30,13 @@ def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None:
with TemporaryDirectory() as temp_dir:
mock_inference_pipeline.save_pretrained(temp_dir)
loaded = StaticModelPipeline.from_pretrained(temp_dir)
assert loaded.predict("dog") == ["b"]
assert loaded.predict(["dog"]) == ["b"]
target: list[str] | list[list[str]]
if mock_inference_pipeline.multilabel:
target = [["a", "b"]]
else:
target = ["b"]
assert loaded.predict("dog").tolist() == target
assert loaded.predict(["dog"]).tolist() == target
assert loaded.predict_proba("dog").argmax() == 1
assert loaded.predict_proba(["dog"]).argmax(1).tolist() == [1]

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8e944ab

Please sign in to comment.