Merge pull request #460 from tomaarsen/hotfix/python-ify-evaluation-results

Fix: Python-ify evaluation results before writing model card

Fix: Python-ify evaluation results before writing model card
tomaarsen committed Dec 7, 2023
2 parents ba6ea7b + 7f95b90 commit 083a489
Showing 2 changed files with 69 additions and 2 deletions.
56 changes: 56 additions & 0 deletions docs/source/en/how_to/model_cards.mdx
@@ -98,3 +98,59 @@ Carbon emissions were measured using [CodeCarbon](https://github.com/mlco2/codecarbon)

- **Carbon Emitted**: 0.003 kg of CO2
- **Hours Used**: 0.072 hours

## Custom Metrics

If you use a custom metric function, its results will be included in your model card as well. For example, if you use the following `metric` function:

```py
from setfit import SetFitModel, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

...

def compute_metrics(y_pred, y_test):
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

...

trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
metric=compute_metrics,
)
trainer.train()

model.save_pretrained("setfit-bge-small-v1.5-sst2-8-shot")
```
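Conceptually, the Trainer calls the metric function with the model's predictions first and the gold labels second, then records the returned dictionary. A minimal stdlib-only sketch of that contract (the `evaluate` helper and toy `predict` model here are illustrative, not SetFit's actual implementation):

```python
def evaluate(predict, eval_texts, eval_labels, metric):
    """Run the model over the eval set and hand (y_pred, y_test) to the metric callable."""
    y_pred = [predict(text) for text in eval_texts]
    return metric(y_pred, eval_labels)

# Toy "model" and a metric with the same (y_pred, y_test) signature as compute_metrics above.
predict = lambda text: int("good" in text)
accuracy = lambda y_pred, y_test: {"accuracy": sum(p == t for p, t in zip(y_pred, y_test)) / len(y_test)}

print(evaluate(predict, ["good movie", "bad movie"], [1, 0], accuracy))  # {'accuracy': 1.0}
```

The dictionary keys you return ("accuracy", "precision", and so on) are what end up as the metric names in the model card.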

The final model card will then contain your custom metrics. For example, the metadata will include:

```yaml
metrics:
- type: accuracy
value: 0.8061504667764964
name: Accuracy
- type: precision
value: 0.7293729372937293
name: Precision
- type: recall
value: 0.9724972497249725
name: Recall
- type: f1
value: 0.8335690711928335
name: F1
```

Additionally, the Evaluation section will display your metrics:

<h4>Metrics</h4>

| Label | Accuracy | Precision | Recall | F1 |
|:--------|:---------|:----------|:-------|:-------|
| **all** | 0.8062 | 0.7294 | 0.9725 | 0.8336 |
15 changes: 13 additions & 2 deletions src/setfit/model_card.py
@@ -457,9 +457,20 @@ def set_st_id(self, model_id: str) -> None:
```diff
         self.st_id = model_id

     def post_training_eval_results(self, results: Dict[str, float]) -> None:
-        self.eval_results_dict = results
+        def try_to_pure_python(value: Any) -> Any:
+            """Try to convert a value from a Numpy or Torch scalar to pure Python, if not already pure Python"""
+            try:
+                if hasattr(value, "dtype"):
+                    return value.item()
+            except Exception:
+                pass
+            return value

-        results_without_split = {key.split("_", maxsplit=1)[1].title(): value for key, value in results.items()}
+        pure_python_results = {key: try_to_pure_python(value) for key, value in results.items()}
+        results_without_split = {
+            key.split("_", maxsplit=1)[1].title(): value for key, value in pure_python_results.items()
+        }
+        self.eval_results_dict = pure_python_results
         self.metric_lines = [{"Label": "**all**", **results_without_split}]

     def _maybe_round(self, v, decimals=4):
```
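The conversion the patch adds can be sketched in isolation. NumPy and Torch zero-dimensional scalars expose a `dtype` attribute and an `item()` method; `FakeScalar` below is a hypothetical stand-in for such a scalar so the sketch stays dependency-free, and is not part of SetFit:

```python
from typing import Any

def try_to_pure_python(value: Any) -> Any:
    """Unwrap a NumPy/Torch-style scalar to pure Python; return anything else unchanged."""
    try:
        if hasattr(value, "dtype"):
            return value.item()
    except Exception:
        pass
    return value

class FakeScalar:
    """Hypothetical stand-in for a NumPy/Torch 0-d scalar: has `dtype` and `item()`."""
    def __init__(self, value):
        self.dtype = "float64"
        self._value = value

    def item(self):
        return self._value

results = {"eval_accuracy": FakeScalar(0.8062), "eval_f1": 0.8336}
# Values carrying a `dtype` are unwrapped via `.item()`; plain Python values pass through.
pure = {key: try_to_pure_python(value) for key, value in results.items()}
print(pure)  # {'eval_accuracy': 0.8062, 'eval_f1': 0.8336}
```

Without this step, library scalars would leak into `eval_results_dict` and serialize as their wrapper types rather than clean numbers in the model card's YAML metadata.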
