-
Notifications
You must be signed in to change notification settings - Fork 45
/
eval_f1.py
51 lines (40 loc) · 1.3 KB
/
eval_f1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from parlai.scripts.eval_model import eval_model, setup_args as base_setup_args
# Selects which persona variant this script evaluates by default: True picks
# the "original" teacher/checkpoint pair, False the "revised" pair. Read by
# setup_task() and setup_trained_weights().
IS_ORIGINAL: bool = True
def setup_task(is_original=None):
    """Return the ParlAI task string for the selected ConvAI2 persona variant.

    Args:
        is_original: ``True`` selects ``SelfOriginalTeacher``, ``False``
            selects ``SelfRevisedTeacher``. ``None`` (the default) falls back
            to the module-level ``IS_ORIGINAL`` flag, so existing
            zero-argument callers behave exactly as before.

    Returns:
        A ``'module:Class'`` teacher path used as ParlAI's ``task`` option.
    """
    if is_original is None:
        # Preserve the historical behavior: defer to the global switch.
        is_original = IS_ORIGINAL
    if is_original:
        return 'tasks.convai2transmitter.agents:SelfOriginalTeacher'
    return 'tasks.convai2transmitter.agents:SelfRevisedTeacher'
def setup_trained_weights(is_original=None):
    """Return the checkpoint path matching the selected persona variant.

    Args:
        is_original: ``True`` for the original-persona checkpoint, ``False``
            for the revised one. ``None`` (the default) falls back to the
            module-level ``IS_ORIGINAL`` flag, matching the old behavior.

    Returns:
        Path of the ``.model`` file, relative to the working directory.
    """
    if is_original is None:
        # Preserve the historical behavior: defer to the global switch.
        is_original = IS_ORIGINAL
    # NOTE(review): the file name is spelled 'psqaure' while the directory is
    # 'psquare'. Kept verbatim because it must match the checkpoint on disk;
    # confirm against the training/download script before "fixing" it.
    if is_original:
        return './tmp/psquare/psqaure_original.model'
    return './tmp/psquare/psqaure_revised.model'
def setup_args(parser=None):
    """Build the evaluation argument parser with this script's defaults.

    Starts from ParlAI's ``eval_model`` arguments (added to ``parser`` when
    one is supplied), then defaults the run to the ``valid`` datatype of the
    teacher chosen by ``setup_task``, with ``hide_labels=False`` and the
    ``f1,bleu`` metrics.

    Returns:
        The configured parser.
    """
    eval_parser = base_setup_args(parser)
    eval_defaults = {
        'task': setup_task(),
        'datatype': 'valid',
        'hide_labels': False,
        'metrics': 'f1,bleu',
    }
    eval_parser.set_defaults(**eval_defaults)
    return eval_parser
def eval_f1(opt, print_parser):
    """Run ParlAI evaluation and print the final F1 and BLEU scores.

    Args:
        opt: parsed ParlAI options, forwarded unchanged to ``eval_model``.
        print_parser: parser forwarded to ``eval_model`` as its second
            argument.

    Returns:
        The report returned by ``eval_model``, so callers can inspect the
        metrics programmatically. It was previously discarded; returning it
        is backward-compatible.
    """
    report = eval_model(opt, print_parser)
    # A requested metric may be absent from the report (NOTE(review): e.g.
    # BLEU when its optional backend is unavailable; confirm for the pinned
    # ParlAI version). Print a placeholder rather than raising KeyError after
    # the whole, possibly long, evaluation has already finished.
    f1_score = report.get('f1', 'N/A')
    bleu_score = report.get('bleu', 'N/A')
    print('============================')
    print('Final F1@1: {}, BLEU: {}'.format(f1_score, bleu_score))
    return report
if __name__ == '__main__':
    # Build the parser with task defaults first, then pin the model settings
    # (transmitter agent, checkpoint for the selected variant, GPU 0,
    # batch size 10, beam size 2) before parsing the command line.
    arg_parser = setup_args()
    arg_parser.set_params(
        model='agents.transmitter.transmitter:TransformerAgent',
        model_file=setup_trained_weights(),
        gpu=0,
        batchsize=10,
        beam_size=2,
    )
    parsed_opt = arg_parser.parse_args(print_args=False)
    eval_f1(parsed_opt, print_parser=arg_parser)