From d70ab214213d2b4953a1a652ed0ac9418f647d68 Mon Sep 17 00:00:00 2001
From: moutozf <3150102181@zju.edu.cn>
Date: Mon, 4 Mar 2024 20:41:19 +0800
Subject: [PATCH 1/3] add VES and exact match scores for BIRD dataset

---
 dbgpt_hub/eval/evaluation_bird.py | 193 +++++++++++++++++-------------
 1 file changed, 107 insertions(+), 86 deletions(-)

diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py
index b84744a..2f94d09 100644
--- a/dbgpt_hub/eval/evaluation_bird.py
+++ b/dbgpt_hub/eval/evaluation_bird.py
@@ -2,17 +2,19 @@
 do evaluate about the predict sql in dataset BIRD,compare with default dev.sql --db
 """
-
 import sys
 import json
 import argparse
 import sqlite3
 import multiprocessing as mp
 from func_timeout import func_timeout, FunctionTimedOut
+import math
+import time
+


 def load_json(dir):
-    with open(dir, "r") as j:
+    with open(dir, 'r') as j:
         contents = json.loads(j.read())
     return contents


@@ -25,40 +27,47 @@ def execute_sql(predicted_sql, ground_truth, db_path):
     conn = sqlite3.connect(db_path)  # Connect to the database
     cursor = conn.cursor()
+    pred_start_time = time.time()
     cursor.execute(predicted_sql)
+    pred_exec_time = time.time() - pred_start_time
     predicted_res = cursor.fetchall()
+    true_start_time = time.time()
     cursor.execute(ground_truth)
+    true_exec_time = time.time() - true_start_time
     ground_truth_res = cursor.fetchall()
     res = 0
+    time_ratio = 0
     if set(predicted_res) == set(ground_truth_res):
         res = 1
-    return res
+        time_ratio = true_exec_time/pred_exec_time if pred_exec_time > 0 else 0
+    return res, time_ratio


 def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     try:
-        res = func_timeout(
-            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
-        )
+        res, time_ratio = func_timeout(meta_time_out, execute_sql,
+                                       args=(predicted_sql, ground_truth, db_place))
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
-        result = [(f"timeout",)]
+        result = [(f'timeout',)]
         res = 0
+        time_ratio = 0
     except Exception as e:
-        result = [(f"error",)]  # possibly len(query) > 512 or not executable
+        result = [(f'error',)]  # possibly len(query) > 512 or not executable
         res = 0
+        time_ratio = 0
     # print(result)
     # result = str(set([ret[0] for ret in result]))
-    result = {"sql_idx": idx, "res": res}
+    result = {'sql_idx': idx, 'res': res, "match": int(predicted_sql == ground_truth), "time_ratio": time_ratio}
     # print(result)
     return result


-def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
+def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
     clean_sqls = []
     db_path_list = []
-    if mode == "gpt":
+    if mode == 'gpt':
         # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
         # for idx, sql_str in sql_data.items():
         #     if type(sql_str) == str:
         #         sql, db_name = sql_str.split('\t----- bird -----\t')
@@ -75,14 +84,14 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
             # sql, db_name = l.split('\t')
             clean_sqls.append(l.strip())
             # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
-    elif mode == "gt":
+    elif mode == 'gt':
         sqls = open(sql_path)
         sql_txt = sqls.readlines()
         # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
         for idx, sql_str in enumerate(sql_txt):
-            sql, db_name = sql_str.strip().split("\t")
+            sql, db_name = sql_str.strip().split('\t')
             clean_sqls.append(sql)
-            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
+            db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')

     return clean_sqls, db_path_list

@@ -101,106 +110,118 @@ def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):


 def sort_results(list_of_dicts):
-    return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
+    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
+
+
+def compute_ves(exec_results):
+    num_queries = len(exec_results)
+    total_ratio = 0
+    count = 0
+    for i, result in enumerate(exec_results):
+        if result['time_ratio'] != 0:
+            count += 1
+            total_ratio += math.sqrt(result['time_ratio']) * 100
+    ves = (total_ratio/num_queries)
+    return ves


-def compute_acc_by_diff(exec_results, diff_json_path):
+
+def compute_acc_by_diff(exec_results, diff_json_path, metric):
     num_queries = len(exec_results)
-    results = [res["res"] for res in exec_results]
+    results = [res[metric] for res in exec_results]
     contents = load_json(diff_json_path)
     simple_results, moderate_results, challenging_results = [], [], []

     for i, content in enumerate(contents):
-        if content["difficulty"] == "simple":
+        if content['difficulty'] == 'simple':
             simple_results.append(exec_results[i])

-        if content["difficulty"] == "moderate":
+        if content['difficulty'] == 'moderate':
             moderate_results.append(exec_results[i])

-        if content["difficulty"] == "challenging":
+        if content['difficulty'] == 'challenging':
             challenging_results.append(exec_results[i])
-
-    simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
-    moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
-    challenging_acc = sum([res["res"] for res in challenging_results]) / len(
-        challenging_results
-    )
-    all_acc = sum(results) / num_queries
-    count_lists = [
-        len(simple_results),
-        len(moderate_results),
-        len(challenging_results),
-        num_queries,
-    ]
-    return (
-        simple_acc * 100,
-        moderate_acc * 100,
-        challenging_acc * 100,
-        all_acc * 100,
-        count_lists,
-    )
-
-
-def print_data(score_lists, count_lists):
-    levels = ["simple", "moderate", "challenging", "total"]
+    if metric in ["res", "match"]:
+        simple_acc = sum([res[metric] for res in simple_results]) / len(simple_results)
+        moderate_acc = sum([res[metric] for res in moderate_results]) / len(moderate_results)
+        challenging_acc = sum([res[metric] for res in challenging_results]) / len(challenging_results)
+        all_acc = sum(results) / num_queries
+    elif metric in ["time_ratio"]:
+        simple_acc = compute_ves(simple_results)
+        moderate_acc = compute_ves(moderate_results)
+        challenging_acc = compute_ves(challenging_results)
+        all_acc = compute_ves(exec_results)
+    else:
+        raise NotImplementedError(f"metric: {metric} is not supported")
+    count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
+    if metric in ["res", "match"]:
+        return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
+    else:
+        return simple_acc, moderate_acc, challenging_acc, all_acc, count_lists
+
+
+def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
+    levels = ['simple', 'moderate', 'challenging', 'total']
     print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
-    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
+    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))

-    print(
-        "====================================== ACCURACY ====================================="
-    )
-    print(
-        "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
-    )
+    print(f'====================================== {metric} =====================================')
+    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))


-if __name__ == "__main__":
+if __name__ == '__main__':
     args_parser = argparse.ArgumentParser()
-    args_parser.add_argument(
-        "--predicted_sql_path", type=str, required=True, default=""
-    )
-    args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
-    args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
-    args_parser.add_argument("--db_root_path", type=str, required=True, default="")
-    args_parser.add_argument("--num_cpus", type=int, default=1)
-    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
-    args_parser.add_argument("--mode_gt", type=str, default="gt")
-    args_parser.add_argument("--mode_predict", type=str, default="gpt")
-    args_parser.add_argument("--difficulty", type=str, default="simple")
-    args_parser.add_argument("--diff_json_path", type=str, default="")
+    args_parser.add_argument('--predicted_sql_path', type=str, default='../../pred_sql/pred_sql_bird_qwen14b_1212.sql')
+    args_parser.add_argument('--ground_truth_path', type=str, default='../../dbgpt_hub/data/bird/dev/dev.sql')
+    args_parser.add_argument('--data_mode', type=str, default='dev')
+    args_parser.add_argument('--db_root_path', type=str, default='../../dbgpt_hub/data/bird/dev/dev_databases/')
+    args_parser.add_argument('--num_cpus', type=int, default=1)
+    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
+    args_parser.add_argument('--mode_gt', type=str, default='gt')
+    args_parser.add_argument('--mode_predict', type=str, default='gpt')
+    args_parser.add_argument('--difficulty', type=str, default='simple')
+    args_parser.add_argument('--diff_json_path', type=str, default='')
+    args_parser.add_argument("--etype", dest="etype", type=str, default="match", choices=("all", "exec", "match", "ves"),)
+
     args = args_parser.parse_args()
     exec_result = []

-    pred_queries, db_paths = package_sqls(
-        args.predicted_sql_path,
-        args.db_root_path,
-        mode=args.mode_predict,
-        data_mode=args.data_mode,
-    )
+    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict,
+                                          data_mode=args.data_mode)
     # generate gt sqls:
-    gt_queries, db_paths_gt = package_sqls(
-        args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
-    )
+    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
+                                           data_mode=args.data_mode)

     if len(db_paths) == 0:
         db_paths = db_paths_gt

     query_pairs = list(zip(pred_queries, gt_queries))
-    run_sqls_parallel(
-        query_pairs,
-        db_places=db_paths,
-        num_cpus=args.num_cpus,
-        meta_time_out=args.meta_time_out,
-    )
+    if args.etype in ["all", "exec", "ves"]:
+        run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
+    else:
+        for i, sql_pair in enumerate(query_pairs):
+            predicted_sql, ground_truth = sql_pair
+            exec_result.append({'sql_idx': i, 'match': int(predicted_sql == ground_truth)})

     exec_result = sort_results(exec_result)

-    print("start calculate")
-    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
-        exec_result, args.diff_json_path
-    )
-    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
-    print_data(score_lists, count_lists)
+    print('start calculate')
+    if args.etype in ["all", "exec"]:
+        simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
+            compute_acc_by_diff(exec_result, args.diff_json_path, "res")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Exec Accuracy")
+    if args.etype in ["all", "match"]:
+        simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
+            compute_acc_by_diff(exec_result, args.diff_json_path, "match")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Match Accuracy")
+    if args.etype in ["all", "ves"]:
+        simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
+            compute_acc_by_diff(exec_result, args.diff_json_path, "time_ratio")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Ves")
     print(
-        "==========================================================================================="
+        '==========================================================================================='
     )
     print("Finished evaluation")

From 68aa8d3f08809040bacf265c502a908e51cef244 Mon Sep 17 00:00:00 2001
From: moutozf <3150102181@zju.edu.cn>
Date: Mon, 4 Mar 2024 20:56:42 +0800
Subject: [PATCH 2/3] add VES and exact match scores for BIRD dataset and
 reformat

---
 dbgpt_hub/eval/evaluation_bird.py | 67 +++++++++++++++----------------
 1 file changed, 33 insertions(+), 34 deletions(-)

diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py
index 2f94d09..40c9852 100644
--- a/dbgpt_hub/eval/evaluation_bird.py
+++ b/dbgpt_hub/eval/evaluation_bird.py
@@ -12,9 +12,8 @@
 import time


-
 def load_json(dir):
-    with open(dir, 'r') as j:
+    with open(dir, "r") as j:
         contents = json.loads(j.read())
     return contents

@@ -50,25 +49,25 @@ def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
-        result = [(f'timeout',)]
+        result = [(f"timeout",)]
         res = 0
         time_ratio = 0
     except Exception as e:
-        result = [(f'error',)]  # possibly len(query) > 512 or not executable
+        result = [(f"error",)]  # possibly len(query) > 512 or not executable
         res = 0
         time_ratio = 0
     # print(result)
     # result = str(set([ret[0] for ret in result]))
-    result = {'sql_idx': idx, 'res': res, "match": int(predicted_sql == ground_truth), "time_ratio": time_ratio}
+    result = {"sql_idx": idx, "res": res, "match": int(predicted_sql == ground_truth), "time_ratio": time_ratio}
     # print(result)
     return result


-def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
+def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
     clean_sqls = []
     db_path_list = []
-    if mode == 'gpt':
+    if mode == "gpt":
         # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
         # for idx, sql_str in sql_data.items():
         #     if type(sql_str) == str:
@@ -84,14 +83,14 @@ def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
             # sql, db_name = l.split('\t')
             clean_sqls.append(l.strip())
             # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
-    elif mode == 'gt':
+    elif mode == "gt":
         sqls = open(sql_path)
         sql_txt = sqls.readlines()
         # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
         for idx, sql_str in enumerate(sql_txt):
-            sql, db_name = sql_str.strip().split('\t')
+            sql, db_name = sql_str.strip().split("\t")
             clean_sqls.append(sql)
-            db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
+            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")

     return clean_sqls, db_path_list

@@ -110,7 +109,7 @@ def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):


 def sort_results(list_of_dicts):
-    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
+    return sorted(list_of_dicts, key=lambda x: x["sql_idx"])


 def compute_ves(exec_results):
@@ -119,9 +118,9 @@ def compute_ves(exec_results):
     count = 0
     for i, result in enumerate(exec_results):
-        if result['time_ratio'] != 0:
+        if result["time_ratio"] != 0:
             count += 1
-            total_ratio += math.sqrt(result['time_ratio']) * 100
+            total_ratio += math.sqrt(result["time_ratio"]) * 100
     ves = (total_ratio/num_queries)
     return ves

@@ -134,13 +133,13 @@ def compute_acc_by_diff(exec_results, diff_json_path, metric):
     simple_results, moderate_results, challenging_results = [], [], []

     for i, content in enumerate(contents):
-        if content['difficulty'] == 'simple':
+        if content["difficulty"] == "simple":
             simple_results.append(exec_results[i])

-        if content['difficulty'] == 'moderate':
+        if content["difficulty"] == "moderate":
             moderate_results.append(exec_results[i])

-        if content['difficulty'] == 'challenging':
+        if content["difficulty"] == "challenging":
             challenging_results.append(exec_results[i])
     if metric in ["res", "match"]:
         simple_acc = sum([res[metric] for res in simple_results]) / len(simple_results)
@@ -162,26 +161,26 @@ def compute_acc_by_diff(exec_results, diff_json_path, metric):


 def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
-    levels = ['simple', 'moderate', 'challenging', 'total']
+    levels = ["simple", "moderate", "challenging", "total"]
     print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
     print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))

     print(f'====================================== {metric} =====================================')
-    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
+    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists))


-if __name__ == '__main__':
+if __name__ == "__main__":
     args_parser = argparse.ArgumentParser()
-    args_parser.add_argument('--predicted_sql_path', type=str, default='../../pred_sql/pred_sql_bird_qwen14b_1212.sql')
-    args_parser.add_argument('--ground_truth_path', type=str, default='../../dbgpt_hub/data/bird/dev/dev.sql')
-    args_parser.add_argument('--data_mode', type=str, default='dev')
-    args_parser.add_argument('--db_root_path', type=str, default='../../dbgpt_hub/data/bird/dev/dev_databases/')
-    args_parser.add_argument('--num_cpus', type=int, default=1)
-    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
-    args_parser.add_argument('--mode_gt', type=str, default='gt')
-    args_parser.add_argument('--mode_predict', type=str, default='gpt')
-    args_parser.add_argument('--difficulty', type=str, default='simple')
-    args_parser.add_argument('--diff_json_path', type=str, default='')
+    args_parser.add_argument("--predicted_sql_path", type=str, default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql")
+    args_parser.add_argument("--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql")
+    args_parser.add_argument("--data_mode", type=str, default="dev")
+    args_parser.add_argument("--db_root_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev_databases/")
+    args_parser.add_argument("--num_cpus", type=int, default=1)
+    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
+    args_parser.add_argument("--mode_gt", type=str, default="gt")
+    args_parser.add_argument("--mode_predict", type=str, default="gpt")
+    args_parser.add_argument("--difficulty", type=str, default="simple")
+    args_parser.add_argument("--diff_json_path", type=str, default="")
     args_parser.add_argument("--etype", dest="etype", type=str, default="match", choices=("all", "exec", "match", "ves"),)

     args = args_parser.parse_args()
@@ -190,7 +189,7 @@ def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
     pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict,
                                           data_mode=args.data_mode)
     # generate gt sqls:
-    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
+    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode="gt",
                                            data_mode=args.data_mode)

     if len(db_paths) == 0:
@@ -202,10 +201,10 @@ def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
     else:
         for i, sql_pair in enumerate(query_pairs):
             predicted_sql, ground_truth = sql_pair
-            exec_result.append({'sql_idx': i, 'match': int(predicted_sql == ground_truth)})
+            exec_result.append({"sql_idx": i, "match": int(predicted_sql == ground_truth)})

     exec_result = sort_results(exec_result)
-    print('start calculate')
+    print("start calculate")
     if args.etype in ["all", "exec"]:
@@ -222,6 +221,6 @@ def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
         score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
         print_data(score_lists, count_lists, metric="Ves")
     print(
-        '==========================================================================================='
+        "==========================================================================================="
     )
     print("Finished evaluation")

From 41f11b688a365b5c95b9c5ccd665786b903d0564 Mon Sep 17 00:00:00 2001
From: moutozf <3150102181@zju.edu.cn>
Date: Tue, 5 Mar 2024 22:18:20 +0800
Subject: [PATCH 3/3] reformat with black

---
 dbgpt_hub/eval/evaluation_bird.py | 125 +++++++++++++++++++++++-------
 1 file changed, 96 insertions(+), 29 deletions(-)

diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py
index 40c9852..aee545e 100644
--- a/dbgpt_hub/eval/evaluation_bird.py
+++ b/dbgpt_hub/eval/evaluation_bird.py
@@ -38,14 +38,15 @@ def execute_sql(predicted_sql, ground_truth, db_path):
     time_ratio = 0
     if set(predicted_res) == set(ground_truth_res):
         res = 1
-        time_ratio = true_exec_time/pred_exec_time if pred_exec_time > 0 else 0
+        time_ratio = true_exec_time / pred_exec_time if pred_exec_time > 0 else 0
     return res, time_ratio


 def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     try:
-        res, time_ratio = func_timeout(meta_time_out, execute_sql,
-                                       args=(predicted_sql, ground_truth, db_place))
+        res, time_ratio = func_timeout(
+            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
+        )
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
@@ -58,7 +59,12 @@ def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
         time_ratio = 0
     # print(result)
     # result = str(set([ret[0] for ret in result]))
-    result = {"sql_idx": idx, "res": res, "match": int(predicted_sql == ground_truth), "time_ratio": time_ratio}
+    result = {
+        "sql_idx": idx,
+        "res": res,
+        "match": int(predicted_sql == ground_truth),
+        "time_ratio": time_ratio,
+    }
     # print(result)
     return result

@@ -121,11 +127,10 @@ def compute_ves(exec_results):
         if result["time_ratio"] != 0:
             count += 1
             total_ratio += math.sqrt(result["time_ratio"]) * 100
-    ves = (total_ratio/num_queries)
+    ves = total_ratio / num_queries
     return ves

-
 def compute_acc_by_diff(exec_results, diff_json_path, metric):
     num_queries = len(exec_results)
     results = [res[metric] for res in exec_results]
     contents = load_json(diff_json_path)
@@ -143,8 +148,12 @@ def compute_acc_by_diff(exec_results, diff_json_path, metric):
         challenging_results.append(exec_results[i])
     if metric in ["res", "match"]:
         simple_acc = sum([res[metric] for res in simple_results]) / len(simple_results)
-        moderate_acc = sum([res[metric] for res in moderate_results]) / len(moderate_results)
+        moderate_acc = sum([res[metric] for res in moderate_results]) / len(
+            moderate_results
+        )
-        challenging_acc = sum([res[metric] for res in challenging_results]) / len(challenging_results)
+        challenging_acc = sum([res[metric] for res in challenging_results]) / len(
+            challenging_results
+        )
         all_acc = sum(results) / num_queries
     elif metric in ["time_ratio"]:
         simple_acc = compute_ves(simple_results)
@@ -153,9 +162,20 @@ def compute_acc_by_diff(exec_results, diff_json_path, metric):
         all_acc = compute_ves(exec_results)
     else:
         raise NotImplementedError(f"metric: {metric} is not supported")
-    count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
+    count_lists = [
+        len(simple_results),
+        len(moderate_results),
+        len(challenging_results),
+        num_queries,
+    ]
     if metric in ["res", "match"]:
-        return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
+        return (
+            simple_acc * 100,
+            moderate_acc * 100,
+            challenging_acc * 100,
+            all_acc * 100,
+            count_lists,
+        )
     else:
         return simple_acc, moderate_acc, challenging_acc, all_acc, count_lists

@@ -163,61 +183,108 @@ def compute_acc_by_diff(exec_results, diff_json_path, metric):
 def print_data(score_lists, count_lists, metric="Exec ACCURACY"):
     levels = ["simple", "moderate", "challenging", "total"]
     print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
-    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
+    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))

-    print(f'====================================== {metric} =====================================')
-    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
+    print(
+        f"====================================== {metric} ====================================="
+    )
+    print(
+        "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
+    )


 if __name__ == "__main__":
     args_parser = argparse.ArgumentParser()
-    args_parser.add_argument("--predicted_sql_path", type=str, default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql")
-    args_parser.add_argument("--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql")
+    args_parser.add_argument(
+        "--predicted_sql_path",
+        type=str,
+        default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql",
+    )
+    args_parser.add_argument(
+        "--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql"
+    )
     args_parser.add_argument("--data_mode", type=str, default="dev")
-    args_parser.add_argument("--db_root_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev_databases/")
+    args_parser.add_argument(
+        "--db_root_path",
+        type=str,
+        default="../../dbgpt_hub/data/bird/dev/dev_databases/",
+    )
     args_parser.add_argument("--num_cpus", type=int, default=1)
    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
     args_parser.add_argument("--mode_gt", type=str, default="gt")
     args_parser.add_argument("--mode_predict", type=str, default="gpt")
     args_parser.add_argument("--difficulty", type=str, default="simple")
     args_parser.add_argument("--diff_json_path", type=str, default="")
-    args_parser.add_argument("--etype", dest="etype", type=str, default="match", choices=("all", "exec", "match", "ves"),)
+    args_parser.add_argument(
+        "--etype",
+        dest="etype",
+        type=str,
+        default="match",
+        choices=("all", "exec", "match", "ves"),
+    )

     args = args_parser.parse_args()
     exec_result = []

-    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict,
-                                          data_mode=args.data_mode)
+    pred_queries, db_paths = package_sqls(
+        args.predicted_sql_path,
+        args.db_root_path,
+        mode=args.mode_predict,
+        data_mode=args.data_mode,
+    )
     # generate gt sqls:
-    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode="gt",
-                                           data_mode=args.data_mode)
+    gt_queries, db_paths_gt = package_sqls(
+        args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
+    )

     if len(db_paths) == 0:
         db_paths = db_paths_gt

     query_pairs = list(zip(pred_queries, gt_queries))
     if args.etype in ["all", "exec", "ves"]:
-        run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
+        run_sqls_parallel(
+            query_pairs,
+            db_places=db_paths,
+            num_cpus=args.num_cpus,
+            meta_time_out=args.meta_time_out,
+        )
     else:
         for i, sql_pair in enumerate(query_pairs):
             predicted_sql, ground_truth = sql_pair
-            exec_result.append({"sql_idx": i, "match": int(predicted_sql == ground_truth)})
+            exec_result.append(
+                {"sql_idx": i, "match": int(predicted_sql == ground_truth)}
+            )

     exec_result = sort_results(exec_result)
     print("start calculate")
     if args.etype in ["all", "exec"]:
-        simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
-            compute_acc_by_diff(exec_result, args.diff_json_path, "res")
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "res")
         score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
         print_data(score_lists, count_lists, metric="Exec Accuracy")
     if args.etype in ["all", "match"]:
-        simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
-            compute_acc_by_diff(exec_result, args.diff_json_path, "match")
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "match")
         score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
         print_data(score_lists, count_lists, metric="Match Accuracy")
     if args.etype in ["all", "ves"]:
-        simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
-            compute_acc_by_diff(exec_result, args.diff_json_path, "time_ratio")
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "time_ratio")
         score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
         print_data(score_lists, count_lists, metric="Ves")
     print(