-
Notifications
You must be signed in to change notification settings - Fork 168
/
predict.py
executable file
·137 lines (123 loc) · 6.27 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
# Use existing model to predict sql from tables and questions.
#
# For example, you can get a pretrained model from https://github.com/naver/sqlova/releases:
# https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt
# https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_best.pt
#
# Make sure you also have the following support files (see README for where to get them):
# - bert_config_uncased_*.json
# - vocab_uncased_*.txt
#
# Finally, you need some data - some files called:
# - <split>.db
# - <split>.jsonl
# - <split>.tables.jsonl
# - <split>_tok.jsonl # derived using annotate_ws.py
# You can play with the existing train/dev/test splits, or make your own with
# the add_csv.py and add_question.py utilities.
#
# Once you have all that, you are ready to predict, using:
# python predict.py \
# --bert_type_abb uL \ # need to match the architecture of the model you are using
# --model_file <path to models>/model_best.pt \
# --bert_model_file <path to models>/model_bert_best.pt \
# --bert_path <path to bert_config/vocab> \
# --result_path <where to place results> \
# --data_path <path to db/jsonl/tables.jsonl> \
# --split <split>
#
# Results will be in a file called results_<split>.jsonl in the result_path.
import argparse, os
from sqlnet.dbengine import DBEngine
from sqlova.utils.utils_wikisql import *
from train import construct_hyper_param, get_models
# This is a stripped down version of the test() method in train.py - identical, except:
# - does not attempt to measure accuracy and indeed does not expect the data to be labelled.
# - saves plain text sql queries.
#
def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length,
            num_target_layers, detail=False, st_pos=0, cnt_tot=1, EG=False, beam_size=4,
            path_db=None, dset_name='test'):
    """Predict SQL queries for every example served by *data_loader*.

    Stripped-down version of the test() routine in train.py: it does not
    measure accuracy and does not require labelled data, and it collects
    plain-text SQL per example.

    Args:
        data_loader: torch DataLoader yielding un-collated lists of example dicts.
        data_table: table metadata loaded from <split>.tables.jsonl.
        model: trained SQLova decoder (in eval mode after this call).
        model_bert: trained BERT encoder (in eval mode after this call).
        bert_config: BERT configuration matching model_bert.
        tokenizer: BERT word-piece tokenizer.
        max_seq_length: maximum BERT input length in word pieces.
        num_target_layers: number of final BERT layers used as the encoding.
        detail, st_pos, cnt_tot: unused; kept for signature compatibility
            with train.test().
        EG: if True, use execution-guided beam decoding against the database.
        beam_size: beam width for execution-guided decoding.
        path_db: directory containing <dset_name>.db.
        dset_name: split prefix (e.g. 'dev' or 'test').

    Returns:
        list of dicts, one per example, with keys "query" (structured SQL),
        "table_id", "nlu" (the question) and "sql" (plain-text SQL string).
    """
    model.eval()
    model_bert.eval()

    # The database engine is only consulted by execution-guided decoding.
    engine = DBEngine(os.path.join(path_db, f"{dset_name}.db"))

    results = []
    for iB, t in enumerate(data_loader):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)

        # NOTE(fix): the original additionally called get_g(sql_i) and
        # get_g_wvi_corenlp(t) here, which extract ground-truth labels even
        # though their results were never used. Removing the dead calls lets
        # prediction run on genuinely unlabelled data, as the file header
        # promises.

        # Encode question + table headers with BERT.
        wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        if not EG:
            # Greedy decoding (no execution guidance).
            s_sc, s_sa, s_wn, s_wc, s_wo, s_wv = model(wemb_n, l_n, wemb_h, l_hpu, l_hs)
            pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wvi = pred_sw_se(s_sc, s_sa, s_wn, s_wc, s_wo, s_wv)
            pr_wv_str, pr_wv_str_wp = convert_pr_wvi_to_string(pr_wvi, nlu_t, nlu_tt, tt_to_t_idx, nlu)
            pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_wn, pr_wc, pr_wo, pr_wv_str, nlu)
        else:
            # Execution-guided beam decoding: candidate queries are run
            # against the database so non-executable ones can be pruned.
            prob_sca, prob_w, prob_wn_w, pr_sc, pr_sa, pr_wn, pr_sql_i = model.beam_forward(wemb_n, l_n, wemb_h, l_hpu,
                                                                                           l_hs, engine, tb,
                                                                                           nlu_t, nlu_tt,
                                                                                           tt_to_t_idx, nlu,
                                                                                           beam_size=beam_size)
            # Sort where-clause components and regenerate the structured query.
            pr_wc, pr_wo, pr_wv, pr_sql_i = sort_and_generate_pr_w(pr_sql_i)
            # Kept only for symmetry with the no-EG branch; unused afterwards.
            pr_wvi = None
            pr_wv_str = None
            pr_wv_str_wp = None

        pr_sql_q = generate_sql_q(pr_sql_i, tb)

        # One result record per example in the batch.
        for b, (pr_sql_i1, pr_sql_q1) in enumerate(zip(pr_sql_i, pr_sql_q)):
            results1 = {}
            results1["query"] = pr_sql_i1
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_q1
            results.append(results1)

    return results
## Set up hyper parameters and paths
parser = argparse.ArgumentParser()
parser.add_argument("--model_file", required=True, help='model file to use (e.g. model_best.pt)')
parser.add_argument("--bert_model_file", required=True, help='bert model file to use (e.g. model_bert_best.pt)')
parser.add_argument("--bert_path", required=True, help='path to bert files (bert_config*.json etc)')
parser.add_argument("--data_path", required=True, help='path to *.jsonl and *.db files')
parser.add_argument("--split", required=True, help='prefix of jsonl and db files (e.g. dev)')
parser.add_argument("--result_path", required=True, help='directory in which to place results')
# construct_hyper_param (from train.py) adds the shared model flags used below
# (e.g. bS, EG, max_seq_length, num_target_layers, toy_*) and parses argv.
args = construct_hyper_param(parser)
BERT_PT_PATH = args.bert_path
path_save_for_evaluation = args.result_path
# Load pre-trained models
path_model_bert = args.bert_model_file
path_model = args.model_file
args.no_pretraining = True # counterintuitive, but avoids loading unused models
model, model_bert, tokenizer, bert_config = get_models(args, BERT_PT_PATH, trained=True, path_model_bert=path_model_bert, path_model=path_model)
# Load data for the requested split: <split>.jsonl, <split>.tables.jsonl
# (and <split>_tok.jsonl, since no_hs_tok=True skips only header tokens).
dev_data, dev_table = load_wikisql_data(args.data_path, mode=args.split, toy_model=args.toy_model, toy_size=args.toy_size, no_hs_tok=True)
dev_loader = torch.utils.data.DataLoader(
    batch_size=args.bS,
    dataset=dev_data,
    shuffle=False,
    num_workers=1,
    collate_fn=lambda x: x # identity collate: keep a list of example dicts, do not merge values into tensors
)
# Run prediction; no_grad disables autograd bookkeeping since we never backprop.
with torch.no_grad():
    results = predict(dev_loader,
                      dev_table,
                      model,
                      model_bert,
                      bert_config,
                      tokenizer,
                      args.max_seq_length,
                      args.num_target_layers,
                      detail=False,
                      path_db=args.data_path,
                      st_pos=0,
                      dset_name=args.split, EG=args.EG)
# Save results as results_<split>.jsonl in the result directory.
save_for_evaluation(path_save_for_evaluation, results, args.split)