-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathcompute_score.py
67 lines (61 loc) · 2.63 KB
/
compute_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import editdistance
from collections import defaultdict
from utils import Tools
def compute_EM(target, predictions, passk):
    """Exact-match score over the top-`passk` predictions.

    Both target and predictions are normalized by stripping each line and
    dropping blank lines; a prediction is truncated to the target's line
    count before comparison (extra trailing lines are ignored).

    Args:
        target: Ground-truth completion string.
        predictions: Candidate completion strings, best-first.
        passk: Number of leading predictions to consider.

    Returns:
        1 if any considered prediction matches the target line-for-line,
        else 0.
    """
    target_lines = [line.strip() for line in target.splitlines() if line.strip()]
    for prediction in predictions[:passk]:
        prediction_lines = [line.strip() for line in prediction.splitlines() if line.strip()][:len(target_lines)]
        # List equality already covers the length check (a shorter
        # prediction can never equal the target), so one comparison
        # replaces the original three branches; exit on first match.
        if prediction_lines == target_lines:
            return 1
    return 0
def compute_ES(target, predictions, passk):
    """Edit-similarity score over the top-`passk` predictions.

    Each prediction is normalized like the target (lines stripped, blanks
    dropped, truncated to the target's line count), then scored as
    1 - edit_distance / max_length against the normalized target string.

    Args:
        target: Ground-truth completion string.
        predictions: Candidate completion strings, best-first.
        passk: Number of leading predictions to consider.

    Returns:
        The best (maximum) similarity in [0, 1] among the considered
        predictions; 0 if there are no predictions.
    """
    target_lines = [line.strip() for line in target.splitlines() if line.strip()]
    target_str = '\n'.join(target_lines)
    ES_scores = []
    for prediction in predictions[:passk]:
        prediction_lines = [line.strip() for line in prediction.splitlines() if line.strip()][:len(target_lines)]
        prediction_str = '\n'.join(prediction_lines)
        denom = max(len(target_str), len(prediction_str))
        if denom == 0:
            # Both strings empty -> identical; avoid ZeroDivisionError.
            ES_scores.append(1)
            continue
        ES_scores.append(1 - editdistance.eval(target_str, prediction_str) / denom)
    # max() on an empty list raises ValueError; no predictions scores 0.
    return max(ES_scores, default=0)
def compute_score_by_repo_with_metadata(repos, lines, stype, passk=1):
    """Compute and print per-repository scores for a set of samples.

    Args:
        repos: Repository names to include; other repos are skipped.
        lines: Parsed JSONL records; each has `metadata.task_id` (prefixed
            "repo/..."), `metadata.ground_truth`, and `choices[*].text`.
        stype: Score type, 'EM' (exact match) or 'ES' (edit similarity).
        passk: Number of leading choices scored per record.

    Raises:
        ValueError: If `stype` is not 'EM' or 'ES' (the original code hit
            an unbound-variable NameError in this case).

    Prints one line per repo: avg_score, sample count, repo name.
    """
    if stype == 'EM':
        scorer = compute_EM
    elif stype == 'ES':
        scorer = compute_ES
    else:
        raise ValueError(f"unknown score type: {stype!r} (expected 'EM' or 'ES')")
    scores = defaultdict(list)
    for line in lines:
        # task_id is "<repo>/<task>"; group scores by repo.
        repo = line['metadata']['task_id'].split('/')[0]
        if repo not in repos:
            continue
        samples = [choice['text'] for choice in line['choices']]
        scores[repo].append(scorer(line['metadata']['ground_truth'], samples, passk))
    avg_scores = {repo: round(sum(vals) / len(vals), 4) for repo, vals in scores.items()}
    repo_count = {repo: len(vals) for repo, vals in scores.items()}
    print(stype)
    for repo in avg_scores:
        print(f'{avg_scores[repo]}\t{repo_count[repo]}\t{repo}')
if __name__ == '__main__':
    # Repositories included in the evaluation.
    repos = [
        'huggingface_diffusers',
        'nerfstudio-project_nerfstudio',
        'awslabs_fortuna',
        'huggingface_evaluate',
        'google_vizier',
        'alibaba_FederatedScope',
        'pytorch_rl',
        'opendilab_ACE',
    ]
    # Compute scores for a single prediction file.
    file_path = 'output/line-rgrg-ada-ws-20-ss-2_samples.0.jsonl'
    # Load the JSONL once instead of re-reading the file for each metric.
    lines = Tools.load_jsonl(file_path)
    compute_score_by_repo_with_metadata(repos, lines, 'EM', passk=1)
    compute_score_by_repo_with_metadata(repos, lines, 'ES', passk=1)