Commit
Updated docs
Pringled committed Feb 16, 2025
1 parent a51f0bb commit 1c86d5e
Showing 2 changed files with 5 additions and 17 deletions.
README.md: 3 changes (1 addition, 2 deletions)

````diff
@@ -112,8 +112,7 @@ ds = load_dataset("setfit/subj")
 classifier.fit(ds["train"]["text"], ds["train"]["label"])
 
 # Evaluate the classifier
-predictions = classifier.predict(ds["test"]["text"])
-accuracy = np.mean(np.array(predictions) == np.array(ds["test"]["label"])) * 100
+classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
 ```
 
 For advanced usage, please refer to our [usage documentation](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md).
````
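Taken together, the new single-label flow reads roughly as the sketch below. The `StaticModelForClassification` import and the `minishlab/potion-base-8M` model name are assumptions drawn from the surrounding model2vec docs rather than from this diff, and `evaluate` is assumed to return a printable classification report:

```python
from datasets import load_dataset
from model2vec.train import StaticModelForClassification  # assumed import path

# Assumed setup (not part of this diff): a trainable classifier built
# from a pretrained static embedding model.
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-8M")

# Load the subjectivity dataset and fit, as in the README context above.
ds = load_dataset("setfit/subj")
classifier.fit(ds["train"]["text"], ds["train"]["label"])

# One call now replaces the manual predict + np.mean accuracy computation.
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
print(classification_report)
```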
model2vec/train/README.md: 19 changes (4 additions, 15 deletions)

````diff
@@ -44,11 +44,10 @@ test = ds["test"]
 s = perf_counter()
 classifier = classifier.fit(train["text"], train["label"])
 
-predicted = classifier.predict(test["text"])
 print(f"Training took {int(perf_counter() - s)} seconds.")
 # Training took 81 seconds
-accuracy = np.mean([x == y for x, y in zip(predicted, test["label"])]) * 100
-print(f"Achieved {accuracy} test accuracy")
+classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
+print(classification_report)
 # Achieved 91.0 test accuracy
 ```
 
````
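Since this commit only touches documentation, `predict` itself is unchanged, and the single accuracy number the old snippet printed can still be recovered manually. A quick cross-check, reusing `classifier` and `test` from the context above:

```python
import numpy as np

# Manual accuracy, exactly as the docs computed it before this commit.
predicted = classifier.predict(test["text"])
accuracy = np.mean([x == y for x, y in zip(predicted, test["label"])]) * 100
print(f"Achieved {accuracy} test accuracy")
# Achieved 91.0 test accuracy (the figure quoted in the docs)
```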
````diff
@@ -95,18 +94,8 @@ Then, we can evaluate the classifier:
 from sklearn import metrics
 from sklearn.preprocessing import MultiLabelBinarizer
 
-# Make predictions on the test set with a threshold of 0.3
-predictions = classifier.predict(ds["test"]["text"], threshold=0.3)
-
-# Evaluate the classifier
-mlb = MultiLabelBinarizer(classes=classifier.classes)
-y_true = mlb.fit_transform(ds["test"]["labels"])
-y_pred = mlb.transform(predictions)
-
-print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}")
-print(f"Precision: {metrics.precision_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
-print(f"Recall: {metrics.recall_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
-print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
+classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
+print(classification_report)
 # Accuracy: 0.410
 # Precision: 0.527
 # Recall: 0.410
````
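For reference, the deleted block doubles as a manual cross-check of the multilabel numbers: sklearn's subset accuracy and macro-averaged scores over binarized label sets. The 0.3 probability threshold is assumed to carry the same semantics as the `threshold` argument now passed to `evaluate`:

```python
from sklearn import metrics
from sklearn.preprocessing import MultiLabelBinarizer

# Predict label sets with a 0.3 probability threshold, as before.
predictions = classifier.predict(ds["test"]["text"], threshold=0.3)

# Binarize true and predicted label sets over the classifier's classes.
mlb = MultiLabelBinarizer(classes=classifier.classes)
y_true = mlb.fit_transform(ds["test"]["labels"])
y_pred = mlb.transform(predictions)

print(f"Accuracy: {metrics.accuracy_score(y_true, y_pred):.3f}")  # 0.410
print(f"F1: {metrics.f1_score(y_true, y_pred, average='macro', zero_division=0):.3f}")
```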
