feat: Add multilabel classification for training (#191)
* Added multilabel option to training

* Added threshold to predict

* Updated docs

* Removed fallback logic

* Resolved feedback

* Update model2vec/train/README.md

* add multilabel targets, fix tests (#194)

* Fixed bug with array conversion

* Optimized inference performance

* Changed classes to np array

* Added int as possible label type

* Use previous logic

* Updated type check

* Updated type check logic

* Only return object type array for multilabel

---------

Co-authored-by: Stephan Tulkens <[email protected]>
Pringled and stephantul authored Feb 16, 2025
1 parent 84e3fa8 commit 2d51516
Showing 8 changed files with 271 additions and 77 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -105,7 +105,7 @@ from model2vec.train import StaticModelForClassification
 # Initialize a classifier from a pre-trained model
 classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")

-# Load a dataset
+# Load a dataset. Note: both single and multi-label classification datasets are supported
 ds = load_dataset("setfit/subj")

 # Train the classifier on text (X) and labels (y)
56 changes: 50 additions & 6 deletions model2vec/inference/model.py
@@ -7,6 +7,7 @@
 import huggingface_hub
 import numpy as np
 import skops.io
+from sklearn.neural_network import MLPClassifier
 from sklearn.pipeline import Pipeline

 from model2vec.hf_utils import _create_model_card
@@ -21,6 +22,20 @@ def __init__(self, model: StaticModel, head: Pipeline) -> None:
"""Create a pipeline with a StaticModel encoder."""
self.model = model
self.head = head
classifier = self.head[-1]
# Check if the classifier is a multilabel classifier.
# NOTE: this doesn't look robust, but it is.
# Different classifiers, such as OVR wrappers, support multilabel output natively, so we
# can just use predict.
self.multilabel = False
if isinstance(classifier, MLPClassifier):
if classifier.out_activation_ == "logistic":
self.multilabel = True

@property
def classes_(self) -> np.ndarray:
"""The classes of the classifier."""
return self.head.classes_

@classmethod
def from_pretrained(
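For context on the heuristic above (not part of the diff): after fitting, scikit-learn's `MLPClassifier` records the output activation it chose in `out_activation_`: `"softmax"` for single-label multiclass targets, `"logistic"` for multilabel indicator targets. A minimal sketch of that behavior, assuming only numpy and scikit-learn:

```python
import numpy as np
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
X = rng.random((32, 8))

# Single-label multiclass targets: sklearn picks a softmax output layer.
single = MLPClassifier(max_iter=50).fit(X, rng.integers(0, 3, size=32))
print(single.out_activation_)  # "softmax"

# Multilabel indicator targets: sklearn picks a logistic (sigmoid) output layer.
multi = MLPClassifier(max_iter=50).fit(X, rng.integers(0, 2, size=(32, 3)))
print(multi.out_activation_)  # "logistic"
```

(A binary single-label `MLPClassifier` would also report `"logistic"`, which is presumably why the NOTE concedes the check looks less robust than it is.)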
@@ -60,7 +75,7 @@ def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = Fa
         self.model.save_pretrained(temp_dir)
         push_folder_to_hub(Path(temp_dir), repo_id, private, token)

-    def _predict_and_coerce_to_2d(
+    def _encode_and_coerce_to_2d(
         self,
         X: list[str] | str,
         show_progress_bar: bool,
@@ -69,7 +84,7 @@ def _predict_and_coerce_to_2d(
         use_multiprocessing: bool,
         multiprocessing_threshold: int,
     ) -> np.ndarray:
-        """Predict the labels of the input and coerce the output to a matrix."""
+        """Encode the instances and coerce the output to a matrix."""
         encoded = self.model.encode(
             X,
             show_progress_bar=show_progress_bar,
@@ -91,9 +106,21 @@ def predict(
         batch_size: int = 1024,
         use_multiprocessing: bool = True,
         multiprocessing_threshold: int = 10_000,
+        threshold: float = 0.5,
     ) -> np.ndarray:
-        """Predict the labels of the input."""
-        encoded = self._predict_and_coerce_to_2d(
+        """
+        Predict the labels of the input.
+
+        :param X: The input data to predict. Can be a list of strings or a single string.
+        :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
+        :param max_length: The maximum length of the input sequences. Defaults to 512.
+        :param batch_size: The batch size for prediction. Defaults to 1024.
+        :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
+        :param multiprocessing_threshold: The number of samples above which multiprocessing is used. Defaults to 10,000.
+        :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel.
+        :return: The predicted labels.
+        """
+        encoded = self._encode_and_coerce_to_2d(
             X,
             show_progress_bar=show_progress_bar,
             max_length=max_length,
@@ -102,6 +129,13 @@
             multiprocessing_threshold=multiprocessing_threshold,
         )

+        if self.multilabel:
+            out_labels = []
+            proba = self.head.predict_proba(encoded)
+            for vector in proba:
+                out_labels.append(self.classes_[vector > threshold])
+            return np.asarray(out_labels, dtype=object)
+
         return self.head.predict(encoded)

     def predict_proba(
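A hypothetical walkthrough of the thresholding above, with made-up class names and probabilities (illustrative stand-ins, not values from the library):

```python
import numpy as np

# Illustrative stand-ins for self.classes_ and self.head.predict_proba(encoded).
classes = np.array(["anger", "joy", "surprise"])
proba = np.array([
    [0.81, 0.10, 0.62],  # instance 0: two labels clear the threshold
    [0.05, 0.44, 0.12],  # instance 1: no label clears the threshold
])

threshold = 0.5
out_labels = [classes[p > threshold] for p in proba]
print(out_labels)
# [array(['anger', 'surprise'], dtype='<U8'), array([], dtype='<U8')]
```

Rows can hold different numbers of labels, including none, so the result is ragged; that is why the diff returns an object-dtype array rather than a regular 2D matrix.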
@@ -113,8 +147,18 @@
         use_multiprocessing: bool = True,
         multiprocessing_threshold: int = 10_000,
     ) -> np.ndarray:
-        """Predict the probabilities of the labels of the input."""
-        encoded = self._predict_and_coerce_to_2d(
+        """
+        Predict the probabilities of the labels of the input.
+
+        :param X: The input data to predict. Can be a list of strings or a single string.
+        :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
+        :param max_length: The maximum length of the input sequences. Defaults to 512.
+        :param batch_size: The batch size for prediction. Defaults to 1024.
+        :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
+        :param multiprocessing_threshold: The number of samples above which multiprocessing is used. Defaults to 10,000.
+        :return: The predicted probabilities of the labels.
+        """
+        encoded = self._encode_and_coerce_to_2d(
             X,
             show_progress_bar=show_progress_bar,
             max_length=max_length,
50 changes: 50 additions & 0 deletions model2vec/train/README.md
@@ -2,6 +2,8 @@

 Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).

+We support both single-label and multi-label classification; the task is inferred automatically from the labels you provide.
+
 # Installation

 To train, make sure you install the training extra:
@@ -65,6 +67,54 @@ print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} ins
 # Took 67 milliseconds for 2000 instances on CPU.
 ```

+## Multi-label classification
+
+Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
+
+```python
+from datasets import load_dataset
+from model2vec.train import StaticModelForClassification
+
+# Initialize a classifier from a pre-trained model
+classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
+
+# Load a multi-label dataset
+ds = load_dataset("google-research-datasets/go_emotions")
+
+# Inspect some of the labels
+print(ds["train"]["labels"][40:50])
+# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
+
+# Train the classifier on text (X) and labels (y)
+classifier.fit(ds["train"]["text"], ds["train"]["labels"])
+```
+
+Then, we can evaluate the classifier:
+
+```python
+from sklearn import metrics
+from sklearn.preprocessing import MultiLabelBinarizer
+
+# Make predictions on the test set with a threshold of 0.3
+predictions = classifier.predict(ds["test"]["text"], threshold=0.3)
+
+# Evaluate the classifier
+mlb = MultiLabelBinarizer(classes=classifier.classes)
+y_true = mlb.fit_transform(ds["test"]["labels"])
+y_pred = mlb.transform(predictions)
+
+print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}")
+print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
+print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
+print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
+# Accuracy: 0.410
+# Precision: 0.527
+# Recall: 0.410
+# F1: 0.439
+```
+
+The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
+
 # Persistence

 You can turn a classifier into a scikit-learn compatible pipeline, as follows:
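The diff view truncates the persistence example here. As a reference only, a minimal sketch of the conversion, assuming the `to_pipeline` helper and `StaticModelPipeline` loader documented elsewhere in this repository (names and paths are assumptions, not part of this diff):

```python
# Sketch only; assumes the to_pipeline()/StaticModelPipeline API documented
# elsewhere in this repository, plus a trained `classifier` from the snippets above.
from model2vec.inference import StaticModelPipeline

pipeline = classifier.to_pipeline()            # convert the torch head to a scikit-learn pipeline
pipeline.save_pretrained("my_classifier")      # persist to a local folder
loaded = StaticModelPipeline.from_pretrained("my_classifier")

# The loaded pipeline supports the multilabel threshold added in this commit.
predictions = loaded.predict(["I am so happy about this!"], threshold=0.3)
```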