forked from THU-KEG/MAVEN-dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_submission.py
27 lines (27 loc) · 1.03 KB
/
get_submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
import sys
import argparse
import json
import numpy as np
def build_submission(lines, preds):
    """Pair flat predictions with the candidates of each test document.

    Args:
        lines: Iterable of JSONL strings. Every non-blank line must decode to
            an object with an "id" key and a "candidates" list whose items
            each carry an "id" key.
        preds: Sequence (e.g. the array loaded from the ``--preds`` file)
            holding one prediction per candidate, in file order.

    Returns:
        A list with one ``{"id": ..., "predictions": [...]}`` dict per
        document; each prediction is ``{"id": <candidate id>,
        "type_id": int(<prediction>)}``.

    Raises:
        ValueError: If the number of predictions differs from the total
            number of candidates.
    """
    # Blank lines (e.g. a trailing newline at EOF) are not documents.
    docs = [json.loads(line) for line in lines if line.strip()]

    # Validate before building anything, so a mismatch is reported clearly
    # instead of surfacing as an IndexError or a stripped-out assert.
    total_candidates = sum(len(doc["candidates"]) for doc in docs)
    if total_candidates != len(preds):
        raise ValueError(
            "prediction count mismatch: %d candidates in the test data "
            "but %d predictions" % (total_candidates, len(preds))
        )

    records = []
    cursor = 0  # index of the next unused prediction
    for doc in docs:
        mentions = doc["candidates"]
        chunk = preds[cursor:cursor + len(mentions)]
        cursor += len(mentions)
        records.append({
            "id": doc["id"],
            "predictions": [
                # int() turns numpy scalars into JSON-serializable ints.
                {"id": mention["id"], "type_id": int(pred)}
                for mention, pred in zip(mentions, chunk)
            ],
        })
    return records


def main(argv=None):
    """Convert a prediction array into a JSONL submission file.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If the predictions do not match the candidates
            one-to-one (see ``build_submission``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_data", default="../maven/test.jsonl", help="path to the test data file", required=False)
    parser.add_argument("--preds", default="MAVEN/_preds.npy", help="path to the prediction file generated by the run_MAVEN_infer.sh script")
    parser.add_argument("--output", default="../maven/results.jsonl", help="path to the output file")
    args = parser.parse_args(argv)

    preds = np.load(args.preds)
    with open(args.test_data, "r", encoding="utf-8") as fin:
        records = build_submission(fin, preds)

    # The output file is opened only after validation succeeded, so a failed
    # run never leaves a truncated results file behind.
    with open(args.output, "w", encoding="utf-8") as fout:
        fout.writelines(json.dumps(record) + "\n" for record in records)


if __name__ == '__main__':
    main()