forked from aioz-ai/MICCAI21_MMQ
-
Notifications
You must be signed in to change notification settings - Fork 6
/
evaluation_script.py
110 lines (99 loc) · 3.33 KB
/
evaluation_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from utils import *
import math
import json
def bleu(candidate, references, n, weights):
    """Score ``candidate`` against ``references`` with n-gram orders 1..n.

    The brevity penalty and the per-order modified precisions are obtained
    from ``brevity_penalty`` and ``modified_precision`` (not defined in this
    file; presumably supplied by ``from utils import *``), then combined by
    ``calculate_bleu``.

    ``weights`` is aligned to the number of precisions before scoring:
    surplus weights are dropped and missing ones are filled with 0. A
    warning is printed in either case.

    Args:
        candidate: the predicted sentence, passed through to the helpers.
        references: the reference sentences, passed through to the helpers.
        n: highest n-gram order to evaluate.
        weights: per-order weights, indexed from order 1.

    Returns:
        The value returned by ``calculate_bleu``.
    """
    bp = brevity_penalty(candidate, references)
    pn = [
        modified_precision(candidate, references, order)
        for order in range(1, n + 1)
    ]

    precision_count = len(pn)
    warning = None
    if len(weights) > precision_count:
        # More weights than n-gram orders: keep only the leading ones.
        effective_weights = list(weights[:precision_count])
        warning = "(warning: the length of weights is bigger than n)"
    elif len(weights) < precision_count:
        # Fewer weights than n-gram orders: the missing orders get weight 0.
        effective_weights = list(weights) + [0] * (precision_count - len(weights))
        warning = "(warning: the length of weights is smaller than n)"
    else:
        effective_weights = weights

    score = calculate_bleu(effective_weights, pn, n, bp)
    if warning is not None:
        print(warning)
    return score
# BLEU
def calculate_bleu(weights, pn, n, bp):
    """Combine n-gram precisions into a BLEU value.

    Computes ``bp * exp(sum(weights[i] * log(pn[i])))`` over the first ``n``
    entries, skipping every order whose precision is exactly 0 (so a zero
    precision contributes nothing instead of forcing the result to 0).

    Args:
        weights: per-order weights; each used entry is converted with ``float``.
        pn: per-order precisions.
        n: number of leading entries of ``pn`` / ``weights`` to use.
        bp: brevity penalty multiplier.

    Returns:
        ``bp`` multiplied by the exponential of the weighted log-precision sum.
    """
    log_total = 0
    for order in range(n):
        precision = pn[order]
        if precision == 0:
            # log(0) is undefined; this order is simply left out of the sum.
            continue
        log_total += float(weights[order]) * math.log(precision)
    return bp * math.exp(log_total)
# Exact match
def calculate_exactmatch(candidate, reference):
    """Return the share of candidate words that match the reference.

    Both sentences are tokenised with ``split_sentence(..., 1)`` (not defined
    in this file; the code indexes the result by word and sums the values, so
    it is assumed to be a word -> occurrence-count mapping).

    The numerator counts each distinct reference word that also appears in
    the candidate once; the denominator is the total of the candidate's word
    counts.

    Args:
        candidate: the predicted sentence.
        reference: the ground-truth sentence.

    Returns:
        ``matched / total`` as a float, or ``0.0`` when the candidate has no
        words. In that case a warning is printed instead of being returned as
        a string, so the result can always be accumulated numerically.
    """
    candidate_words = split_sentence(candidate, 1)
    reference_words = split_sentence(reference, 1)

    matched = sum(1 for word in reference_words if word in candidate_words)
    total = sum(candidate_words[word] for word in candidate_words)

    if total == 0:
        # Previously this case returned a warning *string*, which broke any
        # caller that summed or averaged the scores.
        print("(warning: length of candidate's words is 0)")
        return 0.0
    return matched / total
# F1
def calculate_f1score(candidate, reference):
    """Return the word-level F1 score of ``candidate`` against ``reference``.

    Both sentences are tokenised with ``split_sentence(..., 1)`` (not defined
    in this file; the code indexes the result by word and adds the values, so
    it is assumed to be a word -> occurrence-count mapping).

    Counting rules:
        * true positives: candidate counts of words present in both sentences;
        * false positives: candidate counts of words absent from the reference;
        * false negatives: reference counts of words absent from the candidate.

    Args:
        candidate: the predicted sentence.
        reference: the ground-truth sentence.

    Returns:
        The F1 score as a float. ``0.0`` is returned when the candidate or the
        reference has no words, or when nothing overlaps. For an empty
        candidate a warning is printed instead of being returned as a string,
        so the result can always be accumulated numerically.
    """
    candidate_words = split_sentence(candidate, 1)
    reference_words = split_sentence(reference, 1)

    if len(candidate_words) == 0:
        # Previously this case returned a warning *string*, which made
        # ``total += calculate_f1score(...)`` raise TypeError.
        print("(warning: length of candidate's words is 0)")
        return 0.0
    if len(reference_words) == 0:
        return 0.0

    tp = 0
    fp = 0
    for word in candidate_words:
        if word in reference_words:
            tp += candidate_words[word]
        else:
            fp += candidate_words[word]
    fn = sum(
        reference_words[word]
        for word in reference_words
        if word not in candidate_words
    )

    if tp == 0:
        # No overlap: precision and recall are both 0, so F1 is 0 and the
        # harmonic mean below would divide by zero.
        return 0.0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)
if __name__ == "__main__":
    # Script entry point: average BLEU-1/2/3 and F1 over every answer whose
    # type is not "yes/no" in the stored prediction list.
    #
    # NOTE(review): four weights are supplied although bleu() is called with
    # n = 1..3; bleu() drops the surplus weights and prints a warning on each
    # call.
    weights = [0.25, 0.25, 0.25, 0.25]
    answer_file = "results/BAN_maml_256/answer_list.json"
    # Context manager so the file handle is closed once the JSON is parsed.
    with open(answer_file, 'r', encoding="utf-8") as handle:
        answer_list = json.load(handle)

    BLEU = [0.0, 0.0, 0.0]
    f1_score = 0.0
    count = 0
    for entry in answer_list:
        if entry['answer_type'] == "yes/no":
            continue
        count += 1
        prediction = entry['predict']
        references = [entry['ref']]
        # BLEU[k] accumulates the score computed with n-gram orders 1..k+1.
        for slot, order in enumerate((1, 2, 3)):
            BLEU[slot] += bleu(prediction, references, order, weights)
        f1 = calculate_f1score(prediction, entry['ref'])
        # The scorer may signal an empty prediction with a non-numeric value;
        # such an entry contributes 0 instead of aborting the whole run.
        if isinstance(f1, (int, float)):
            f1_score += f1

    if count == 0:
        # Nothing to average; avoid dividing by zero.
        print("(warning: no answers other than yes/no were found)")
    else:
        # A list cannot be divided by a number, so average element-wise.
        BLEU = [total / count for total in BLEU]
        print(BLEU)
        f1_score = f1_score / count
        print(f1_score)