Skip to content

Commit

Permalink
reformat code style
Browse files Browse the repository at this point in the history
  • Loading branch information
qidanrui committed Dec 29, 2023
1 parent 15d4e77 commit 44d43fd
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 57 deletions.
29 changes: 19 additions & 10 deletions dbgpt_hub/data_process/sql_data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
self.num_shot = num_shot

def decode_json_file(
self, data_file_list, table_file, db_id_name, output_name, is_multiple_turn=False
self,
data_file_list,
table_file,
db_id_name,
output_name,
is_multiple_turn=False,
):
"""
TO DO:
Expand Down Expand Up @@ -68,9 +73,9 @@ def decode_json_file(
if type(primary_key[j]) == int:
if coloumns[primary_key[j] - 1][0] == i:
source += (
coloumns[primary_key[j] - 1][1]
+ " is the primary key."
+ "\n"
coloumns[primary_key[j] - 1][1]
+ " is the primary key."
+ "\n"
)
# combination primary key
elif type(primary_key[j]) == list:
Expand All @@ -80,10 +85,10 @@ def decode_json_file(
if coloumns[primary_key[j][k] - 1][0] == i:
keys.append(coloumns[primary_key[j][k] - 1][1])
source += (
combine_p +
", ".join(keys)
+ ") are the primary key."
+ "\n"
combine_p
+ ", ".join(keys)
+ ") are the primary key."
+ "\n"
)
else:
print("not support type", type(primary_key[j]))
Expand Down Expand Up @@ -156,7 +161,9 @@ def create_sft_raw_data(self):
self.decode_json_file(
data_file_list=train_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["train_tables_file"]
DATA_PATH,
data_info["data_source"],
data_info["train_tables_file"],
),
db_id_name=data_info["db_id_name"],
output_name=data_info["output_name"],
Expand All @@ -172,7 +179,9 @@ def create_sft_raw_data(self):
self.decode_json_file(
data_file_list=dev_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["dev_tables_file"]
DATA_PATH,
data_info["data_source"],
data_info["dev_tables_file"],
),
db_id_name=data_info["db_id_name"],
output_name=data_info["output_name"],
Expand Down
130 changes: 83 additions & 47 deletions dbgpt_hub/eval/evaluation_bird.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def load_json(dir):
    """Parse the JSON file at *dir* and return the decoded object."""
    with open(dir, "r") as json_file:
        parsed = json.load(json_file)
    return parsed

Expand All @@ -37,27 +37,28 @@ def execute_sql(predicted_sql, ground_truth, db_path):

def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
    """Execute one (predicted, ground-truth) SQL pair under a time limit.

    Args:
        predicted_sql: SQL string produced by the model.
        ground_truth: reference SQL string.
        db_place: path of the SQLite database to run both queries against.
        idx: position of this pair in the evaluation set.
        meta_time_out: per-query timeout in seconds for func_timeout.

    Returns:
        {"sql_idx": idx, "res": res} — res is whatever execute_sql returns
        (presumably 1 on a result-set match, 0 otherwise — confirm against
        execute_sql), forced to 0 on timeout or execution error.
    """
    try:
        res = func_timeout(
            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
        )
    except KeyboardInterrupt:
        # Let Ctrl-C abort the whole evaluation run immediately.
        sys.exit(0)
    except FunctionTimedOut:
        # Query exceeded meta_time_out seconds; score it as a miss.
        # (Fixed: original used f-strings with no placeholders and kept an
        # unused `result` local feeding only commented-out debug code.)
        res = 0
    except Exception:
        # Possibly len(query) > 512 or the SQL is not executable.
        res = 0
    return {"sql_idx": idx, "res": res}


def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
clean_sqls = []
db_path_list = []
if mode == 'gpt':
if mode == "gpt":
# sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
# for idx, sql_str in sql_data.items():
# if type(sql_str) == str:
Expand All @@ -74,14 +75,14 @@ def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
# sql, db_name = l.split('\t')
clean_sqls.append(l.strip())
# db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
elif mode == 'gt':
elif mode == "gt":
sqls = open(sql_path)
sql_txt = sqls.readlines()
# sql_txt = [sql.split('\t')[0] for sql in sql_txt]
for idx, sql_str in enumerate(sql_txt):
sql, db_name = sql_str.strip().split('\t')
sql, db_name = sql_str.strip().split("\t")
clean_sqls.append(sql)
db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")

return clean_sqls, db_path_list

Expand All @@ -90,81 +91,116 @@ def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
pool = mp.Pool(processes=num_cpus)
for i, sql_pair in enumerate(sqls):
predicted_sql, ground_truth = sql_pair
pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
callback=result_callback)
pool.apply_async(
execute_model,
args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
callback=result_callback,
)
pool.close()
pool.join()


def sort_results(list_of_dicts):
    """Return the execution-result dicts ordered by their "sql_idx" field."""

    def index_of(entry):
        return entry["sql_idx"]

    return sorted(list_of_dicts, key=index_of)


def compute_acc_by_diff(exec_results, diff_json_path):
    """Bucket execution results by question difficulty and compute accuracies.

    Args:
        exec_results: list of {"sql_idx": int, "res": 0|1} dicts, aligned
            index-for-index with the entries in the difficulty file.
        diff_json_path: path of a JSON list whose i-th entry carries the
            "difficulty" label ("simple" / "moderate" / "challenging") of
            the i-th query.

    Returns:
        (simple_acc, moderate_acc, challenging_acc, all_acc, count_lists)
        where each accuracy is a percentage and count_lists holds the bucket
        sizes [simple, moderate, challenging, total].
    """
    num_queries = len(exec_results)
    results = [res["res"] for res in exec_results]
    contents = load_json(diff_json_path)

    buckets = {"simple": [], "moderate": [], "challenging": []}
    for i, content in enumerate(contents):
        level = content["difficulty"]
        if level in buckets:
            buckets[level].append(exec_results[i])

    def _pct(bucket):
        # Fixed: the original divided by len(bucket) unconditionally and
        # raised ZeroDivisionError whenever a difficulty level was absent
        # from the evaluation set. Report 0.0 for an empty bucket instead.
        if not bucket:
            return 0.0
        return 100.0 * sum(res["res"] for res in bucket) / len(bucket)

    all_acc = 100.0 * sum(results) / num_queries if num_queries else 0.0
    count_lists = [
        len(buckets["simple"]),
        len(buckets["moderate"]),
        len(buckets["challenging"]),
        num_queries,
    ]
    return (
        _pct(buckets["simple"]),
        _pct(buckets["moderate"]),
        _pct(buckets["challenging"]),
        all_acc,
        count_lists,
    )


def print_data(score_lists, count_lists):
    """Pretty-print per-difficulty query counts and accuracy percentages."""
    levels = ["simple", "moderate", "challenging", "total"]
    header_fmt = "{:20} {:20} {:20} {:20} {:20}"
    count_fmt = "{:20} {:<20} {:<20} {:<20} {:<20}"
    acc_fmt = "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}"

    print(header_fmt.format("", *levels))
    print(count_fmt.format("count", *count_lists))

    print(
        "====================================== ACCURACY ====================================="
    )
    print(acc_fmt.format("accuracy", *score_lists))


if __name__ == "__main__":
    args_parser = argparse.ArgumentParser()
    # Required path / mode arguments (all strings).
    for flag, default in (
        ("--predicted_sql_path", ""),
        ("--ground_truth_path", ""),
        ("--data_mode", "dev"),
        ("--db_root_path", ""),
    ):
        args_parser.add_argument(flag, type=str, required=True, default=default)
    # Optional tuning arguments.
    args_parser.add_argument("--num_cpus", type=int, default=1)
    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
    for flag, default in (
        ("--mode_gt", "gt"),
        ("--mode_predict", "gpt"),
        ("--difficulty", "simple"),
        ("--diff_json_path", ""),
    ):
        args_parser.add_argument(flag, type=str, default=default)
    args = args_parser.parse_args()
    exec_result = []

    # Load predicted SQL; db paths may be empty in predict mode.
    pred_queries, db_paths = package_sqls(
        args.predicted_sql_path,
        args.db_root_path,
        mode=args.mode_predict,
        data_mode=args.data_mode,
    )
    # Load ground-truth SQL, which also yields the per-query db paths.
    gt_queries, db_paths_gt = package_sqls(
        args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
    )

    # Fall back to the ground-truth db list when the predictions carried none.
    if len(db_paths) == 0:
        db_paths = db_paths_gt

    query_pairs = list(zip(pred_queries, gt_queries))
    run_sqls_parallel(
        query_pairs,
        db_places=db_paths,
        num_cpus=args.num_cpus,
        meta_time_out=args.meta_time_out,
    )
    exec_result = sort_results(exec_result)

    print("start calculate")
    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
        exec_result, args.diff_json_path
    )
    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
    print_data(score_lists, count_lists)
    print(
        "==========================================================================================="
    )
    print("Finished evaluation")

0 comments on commit 44d43fd

Please sign in to comment.