-
Notifications
You must be signed in to change notification settings - Fork 6
/
gen_inference_data.py
53 lines (47 loc) · 1.43 KB
/
gen_inference_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import argparse
import copy
import json
# Skeleton of one output record. resdsql_insider fills in "db_id" and the
# "human"/"bot" message contents for every sample; "data_name", "id" and the
# empty "system" message are emitted exactly as written here.
prompt_temple = {
    "data_name": "spider",
    "id": 0,
    "db_id": "",
    # One entry per speaker, numbered 0..2 in conversation order.
    "chat_rounds": [
        {"role": speaker, "content": "", "chat_round_id": turn}
        for turn, speaker in enumerate(("system", "human", "bot"))
    ],
}
def resdsql_insider(schemalinking_result):
    """Convert a schema-linking result file into chat-round prompt records.

    Args:
        schemalinking_result: Path to a JSON file containing a list of
            objects, each with "db_id", "input_sequence" and
            "output_sequence" keys. Only the text after the first "|" of
            "output_sequence" is kept (as the "bot" message).

    Returns:
        A list with one dict per input sample, in input order. Each dict is
        an independent copy of ``prompt_temple`` with "db_id" and the
        "human"/"bot" contents filled in.

    Raises:
        ValueError: If a sample's "output_sequence" contains no "|".
        KeyError: If a sample lacks one of the expected keys.
    """
    # Read as UTF-8 explicitly: main() writes its output as UTF-8, and the
    # platform default encoding is not guaranteed to match.
    with open(schemalinking_result, encoding="utf-8") as f:
        samples = json.load(f)

    results = []
    for index, sample in enumerate(samples):
        # Drop everything up to and including the first "|"; the remainder
        # is kept verbatim (no stripping).
        _, separator, sql = sample["output_sequence"].partition("|")
        if not separator:
            raise ValueError(
                f"sample {index}: 'output_sequence' has no '|' separator"
            )

        # Fill a private copy so the shared module-level template is never
        # mutated and stays pristine for any other user.
        record = copy.deepcopy(prompt_temple)
        record["db_id"] = sample["db_id"]
        record["chat_rounds"][1]["content"] = sample["input_sequence"]
        record["chat_rounds"][2]["content"] = sql
        results.append(record)
    return results
def main(opt):
    """Build the prompt records and write them to ``opt.output`` as JSON.

    Args:
        opt: Parsed command-line options; reads ``schemalinking_result``
            (input path) and ``output`` (destination path).
    """
    records = resdsql_insider(opt.schemalinking_result)
    # ensure_ascii=False keeps non-ASCII text readable in the UTF-8 output.
    with open(opt.output, mode="w", encoding="utf-8") as out_file:
        json.dump(records, out_file, ensure_ascii=False, indent=4)
        out_file.flush()
if __name__ == "__main__":
    # Command-line entry point. Both paths are mandatory, so argparse rejects
    # a missing one up front instead of the script failing later on open("").
    parser_arg = argparse.ArgumentParser(
        description=(
            "Convert a schema-linking result JSON file into chat-round "
            "prompt records."
        )
    )
    parser_arg.add_argument(
        '--schemalinking_result',
        type=str,
        required=True,
        help="path to the input JSON file",
    )
    parser_arg.add_argument(
        '--output',
        type=str,
        required=True,
        help="path of the JSON file to write",
    )
    opt = parser_arg.parse_args()
    main(opt)