-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_bert_score.py
95 lines (81 loc) · 3.21 KB
/
get_bert_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
import sys
import random
import ast
import math
from random import randint
import codecs
import json
import pickle
import bert_score
from bert_score import score
from rouge_score import rouge_scorer
import rouge
import ast
import nltk
import spacy
# Load the spaCy pipeline named "en_core_web_sm" once, at import time.
# NOTE(review): `nlp` is not referenced by any function visible in this
# file -- confirm it is still needed before keeping this import-time load.
nlp = spacy.load("en_core_web_sm")
def main():
    """Command-line entry point.

    Accepts up to three positional arguments, in this order on the command
    line: ``picklefile``, ``txtfile``, ``labelfile``. Each one is optional
    for argparse (``nargs='?'``), so an omitted argument arrives as ``None``.

    Note that the order is swapped when handing over to ``make_data``: the
    text file goes first, then the pickle file, then the label file.
    """
    cli = argparse.ArgumentParser()
    for positional in ("picklefile", "txtfile", "labelfile"):
        cli.add_argument(positional, nargs='?')
    parsed = cli.parse_args()

    txt_path = parsed.txtfile
    pickle_path = parsed.picklefile
    label_path = parsed.labelfile
    make_data(txt_path, pickle_path, label_path)
def make_data(data, test_pickle, label):
    """Print the mean BERTScore F-score of predictions against gold hypotheses.

    Args:
        data: Path to a JSON-lines file. For every row the ``label`` field
            is read; label ``'1'`` selects ``hyp1`` and label ``'2'``
            selects ``hyp2`` as the reference sentence. Rows with any other
            label are skipped.
        test_pickle: Path to a text file holding one Python literal per
            line. The ``["beams"][1]`` entry of each line is the candidate
            sentence. Despite the parameter name, the file is parsed with
            ``ast.literal_eval`` and is never unpickled.
        label: Unused. Kept only so existing callers do not break; the file
            is not opened, so ``None`` is acceptable here.

    Raises:
        IndexError: If a row labelled ``'1'`` or ``'2'`` has no prediction
            on the same line number of ``test_pickle``.
    """
    predictions = _load_predictions(test_pickle)
    candidates, references = _collect_pairs(data, predictions)

    # Only the third value returned by ``score`` is reported.
    _, _, f_scores = score(
        candidates,
        references,
        lang="en",
        model_type="bert-base-uncased",
        verbose=True,
    )
    print(f_scores.mean())


def _load_predictions(path):
    """Return the ``["beams"][1]`` entry of every line in ``path``."""
    with open(path, 'r') as f:
        # Lines are Python literals rather than JSON, hence literal_eval.
        return [ast.literal_eval(line)["beams"][1] for line in f]


def _collect_pairs(path, predictions):
    """Return ``(candidates, references)`` for the labelled rows of ``path``."""
    # Which hypothesis field holds the reference for each label value.
    hypothesis_field = {'1': "hyp1", '2': "hyp2"}

    candidates = []
    references = []
    with open(path, 'r') as f:
        # Predictions are aligned with data rows by line number, so the
        # index advances on every row, including rows that get skipped.
        for row_index, raw in enumerate(f):
            row = json.loads(raw)
            field = hypothesis_field.get(row["label"])
            if field is None:
                continue
            if row_index >= len(predictions):
                raise IndexError(
                    "no prediction for data row {} (only {} predictions "
                    "loaded)".format(row_index, len(predictions))
                )
            # Drop the end-of-sequence marker from the candidate text.
            candidates.append(predictions[row_index].replace('["eos"]', ''))
            references.append(row[field].replace("\n", ""))
    return candidates, references
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ =='__main__':
    main()