diff --git a/src/side/glue_results.py b/src/side/glue_results.py index a1acd64..0b3cb08 100644 --- a/src/side/glue_results.py +++ b/src/side/glue_results.py @@ -1,9 +1,12 @@ import argparse import os import re +import numpy as np import json +import jsonlines import random import evaluate +from pathlib import Path from datasets import load_dataset from collections import defaultdict @@ -227,11 +230,18 @@ def rte_get_pred(pred): ) eval_metric = metric.compute() - for key in list(metric.keys()): + for key in list(eval_metric.keys()): eval_metric[f"{key}_{task_name}"] = eval_metric[key] del eval_metric[key] eval_metrics.update(eval_metric) + + average = [] + for key in ["matthews_correlation_cola", "accuracy_sst2", "f1_mrpc", "spearmanr_stsb", "f1_qqp", "accuracy_mnli", "accuracy_qnli", "accuracy_wnli"]: + average.append(eval_metrics[key]) + + average = np.mean(average) + eval_metrics["average"] = average print(len(labels), eval_metrics)