diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py index c2864aa..db2899b 100644 --- a/dbgpt_hub/data_process/sql_data_process.py +++ b/dbgpt_hub/data_process/sql_data_process.py @@ -24,7 +24,12 @@ def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None: self.num_shot = num_shot def decode_json_file( - self, data_file_list, table_file, db_id_name, output_name, is_multiple_turn=False + self, + data_file_list, + table_file, + db_id_name, + output_name, + is_multiple_turn=False, ): """ TO DO: @@ -68,9 +73,9 @@ def decode_json_file( if type(primary_key[j]) == int: if coloumns[primary_key[j] - 1][0] == i: source += ( - coloumns[primary_key[j] - 1][1] - + " is the primary key." - + "\n" + coloumns[primary_key[j] - 1][1] + + " is the primary key." + + "\n" ) # combination primary key elif type(primary_key[j]) == list: @@ -80,10 +85,10 @@ def decode_json_file( if coloumns[primary_key[j][k] - 1][0] == i: keys.append(coloumns[primary_key[j][k] - 1][1]) source += ( - combine_p + - ", ".join(keys) - + ") are the primary key." - + "\n" + combine_p + + ", ".join(keys) + + ") are the primary key." + + "\n" ) else: print("not support type", type(primary_key[j])) @@ -156,7 +161,9 @@ def create_sft_raw_data(self): self.decode_json_file( data_file_list=train_data_file_list, table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["train_tables_file"] + DATA_PATH, + data_info["data_source"], + data_info["train_tables_file"], ), db_id_name=data_info["db_id_name"], output_name=data_info["output_name"], @@ -172,7 +179,9 @@ def create_sft_raw_data(self): self.decode_json_file( data_file_list=dev_data_file_list, table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["dev_tables_file"] + DATA_PATH, + data_info["data_source"], + data_info["dev_tables_file"], ), db_id_name=data_info["db_id_name"], output_name=data_info["output_name"], diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py index 0c9b4ba..b84744a 100644 --- a/dbgpt_hub/eval/evaluation_bird.py +++ b/dbgpt_hub/eval/evaluation_bird.py @@ -12,7 +12,7 @@ def load_json(dir): - with open(dir, 'r') as j: + with open(dir, "r") as j: contents = json.loads(j.read()) return contents @@ -37,27 +37,28 @@ def execute_sql(predicted_sql, ground_truth, db_path): def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out): try: - res = func_timeout(meta_time_out, execute_sql, - args=(predicted_sql, ground_truth, db_place)) + res = func_timeout( + meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place) + ) except KeyboardInterrupt: sys.exit(0) except FunctionTimedOut: - result = [(f'timeout',)] + result = [(f"timeout",)] res = 0 except Exception as e: - result = [(f'error',)] # possibly len(query) > 512 or not executable + result = [(f"error",)] # possibly len(query) > 512 or not executable res = 0 # print(result) # result = str(set([ret[0] for ret in result])) - result = {'sql_idx': idx, 'res': res} + result = {"sql_idx": idx, "res": res} # print(result) return result -def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): +def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"): clean_sqls = [] db_path_list = [] - if mode == 'gpt': + if mode == "gpt": # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r')) # for idx, sql_str in sql_data.items(): # if type(sql_str) == str: @@ -74,14 +75,14 @@ def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): # sql, db_name = l.split('\t') clean_sqls.append(l.strip()) # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') - elif mode == 'gt': + elif mode == "gt": sqls = open(sql_path) sql_txt = sqls.readlines() # sql_txt = [sql.split('\t')[0] for sql in sql_txt] for idx, sql_str in enumerate(sql_txt): - sql, db_name = sql_str.strip().split('\t') + sql, db_name = sql_str.strip().split("\t") clean_sqls.append(sql) - db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') + db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") return clean_sqls, db_path_list @@ -90,81 +91,116 @@ def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0): pool = mp.Pool(processes=num_cpus) for i, sql_pair in enumerate(sqls): predicted_sql, ground_truth = sql_pair - pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), - callback=result_callback) + pool.apply_async( + execute_model, + args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), + callback=result_callback, + ) pool.close() pool.join() def sort_results(list_of_dicts): - return sorted(list_of_dicts, key=lambda x: x['sql_idx']) + return sorted(list_of_dicts, key=lambda x: x["sql_idx"]) def compute_acc_by_diff(exec_results, diff_json_path): num_queries = len(exec_results) - results = [res['res'] for res in exec_results] + results = [res["res"] for res in exec_results] contents = load_json(diff_json_path) simple_results, moderate_results, challenging_results = [], [], [] for i, content in enumerate(contents): - if content['difficulty'] == 'simple': + if content["difficulty"] == "simple": simple_results.append(exec_results[i]) - if content['difficulty'] == 'moderate': + if content["difficulty"] == "moderate": moderate_results.append(exec_results[i]) - if content['difficulty'] == 'challenging': + if content["difficulty"] == "challenging": challenging_results.append(exec_results[i]) - simple_acc = sum([res['res'] for res in simple_results]) / len(simple_results) - moderate_acc = sum([res['res'] for res in moderate_results]) / len(moderate_results) - challenging_acc = sum([res['res'] for res in challenging_results]) / len(challenging_results) + simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) + moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) + challenging_acc = sum([res["res"] for res in challenging_results]) / len( + challenging_results + ) all_acc = sum(results) / num_queries - count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries] - return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists + count_lists = [ + len(simple_results), + len(moderate_results), + len(challenging_results), + num_queries, + ] + return ( + simple_acc * 100, + moderate_acc * 100, + challenging_acc * 100, + all_acc * 100, + count_lists, + ) def print_data(score_lists, count_lists): - levels = ['simple', 'moderate', 'challenging', 'total'] + levels = ["simple", "moderate", "challenging", "total"] print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) - print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists)) + print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) - print('====================================== ACCURACY =====================================') - print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists)) + print( + "====================================== ACCURACY =====================================" + ) + print( + "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists) + ) -if __name__ == '__main__': +if __name__ == "__main__": args_parser = argparse.ArgumentParser() - args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='') - args_parser.add_argument('--ground_truth_path', type=str, required=True, default='') - args_parser.add_argument('--data_mode', type=str, required=True, default='dev') - args_parser.add_argument('--db_root_path', type=str, required=True, default='') - args_parser.add_argument('--num_cpus', type=int, default=1) - args_parser.add_argument('--meta_time_out', type=float, default=30.0) - args_parser.add_argument('--mode_gt', type=str, default='gt') - args_parser.add_argument('--mode_predict', type=str, default='gpt') - args_parser.add_argument('--difficulty', type=str, default='simple') - args_parser.add_argument('--diff_json_path', type=str, default='') + args_parser.add_argument( + "--predicted_sql_path", type=str, required=True, default="" + ) + args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") + args_parser.add_argument("--data_mode", type=str, required=True, default="dev") + args_parser.add_argument("--db_root_path", type=str, required=True, default="") + args_parser.add_argument("--num_cpus", type=int, default=1) + args_parser.add_argument("--meta_time_out", type=float, default=30.0) + args_parser.add_argument("--mode_gt", type=str, default="gt") + args_parser.add_argument("--mode_predict", type=str, default="gpt") + args_parser.add_argument("--difficulty", type=str, default="simple") + args_parser.add_argument("--diff_json_path", type=str, default="") args = args_parser.parse_args() exec_result = [] - pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict, - data_mode=args.data_mode) + pred_queries, db_paths = package_sqls( + args.predicted_sql_path, + args.db_root_path, + mode=args.mode_predict, + data_mode=args.data_mode, + ) # generate gt sqls: - gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt', - data_mode=args.data_mode) + gt_queries, db_paths_gt = package_sqls( + args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode + ) if len(db_paths) == 0: db_paths = db_paths_gt query_pairs = list(zip(pred_queries, gt_queries)) - run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out) + run_sqls_parallel( + query_pairs, + db_places=db_paths, + num_cpus=args.num_cpus, + meta_time_out=args.meta_time_out, + ) exec_result = sort_results(exec_result) - print('start calculate') - simple_acc, moderate_acc, challenging_acc, acc, count_lists = \ - compute_acc_by_diff(exec_result, args.diff_json_path) + print("start calculate") + simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff( + exec_result, args.diff_json_path + ) score_lists = [simple_acc, moderate_acc, challenging_acc, acc] print_data(score_lists, count_lists) - print('===========================================================================================') + print( + "===========================================================================================" + ) print("Finished evaluation")