diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index f5b2563..268d6d6 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -50,10 +50,22 @@
         "data_source": "spider",
         "train_file": ["train_spider.json", "train_others.json"],
         "dev_file": ["dev.json"],
-        "tables_file": "tables.json",
+        "train_tables_file": "tables.json",
+        "dev_tables_file": "tables.json",
         "db_id_name": "db_id",
+        "output_name": "query",
         "is_multiple_turn": False,
     }
+    # {
+    #     "data_source": "bird",
+    #     "train_file": ["train/train.json"],
+    #     "dev_file": ["dev/dev.json"],
+    #     "train_tables_file": "train/train_tables.json",
+    #     "dev_tables_file": "dev/dev_tables.json",
+    #     "db_id_name": "db_id",
+    #     "output_name": "SQL",
+    #     "is_multiple_turn": False,
+    # }
     # ,
     # {
     #     "data_source": "chase",
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index e869f2e..33e1883 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -24,7 +24,12 @@ def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.num_shot = num_shot
 
     def decode_json_file(
-        self, data_file_list, table_file, db_id_name, is_multiple_turn=False
+        self,
+        data_file_list,
+        table_file,
+        db_id_name,
+        output_name,
+        is_multiple_turn=False,
     ):
         """
         TO DO:
@@ -65,12 +70,29 @@ def decode_json_file(
 
                 # get primary key info
                 for j in range(len(primary_key)):
-                    if coloumns[primary_key[j] - 1][0] == i:
+                    if type(primary_key[j]) == int:
+                        if coloumns[primary_key[j] - 1][0] == i:
+                            source += (
+                                coloumns[primary_key[j] - 1][1]
+                                + " is the primary key."
+                                + "\n"
+                            )
+                    # combination primary key
+                    elif type(primary_key[j]) == list:
+                        combine_p = "The combination of ("
+                        keys = []
+                        for k in range(len(primary_key[j])):
+                            if coloumns[primary_key[j][k] - 1][0] == i:
+                                keys.append(coloumns[primary_key[j][k] - 1][1])
                         source += (
-                            coloumns[primary_key[j] - 1][1]
-                            + " is the primary key."
+                            combine_p
+                            + ", ".join(keys)
+                            + ") are the primary key."
                             + "\n"
                         )
+                    else:
+                        print("unsupported primary key type:", type(primary_key[j]))
+                        continue
 
                 # get foreign key info
                 for key in foreign_keys:
@@ -104,14 +126,14 @@
                                 db_dict[data[db_id_name]]
                             ),
                             "input": INPUT_PROMPT.format(interaction["utterance"]),
-                            "output": interaction["query"],
+                            "output": interaction[output_name],
                             "history": history,
                         }
                         res.append(input)
                         history.append(
                             (
                                 INPUT_PROMPT.format(interaction["utterance"]),
-                                interaction["query"],
+                                interaction[output_name],
                             )
                         )
             else:  # single turn
@@ -121,7 +143,7 @@
                         db_dict[data[db_id_name]]
                     ),
                     "input": INPUT_PROMPT.format(data["question"]),
-                    "output": data["query"],
+                    "output": data[output_name],
                     "history": [],
                 }
                 res.append(input)
@@ -139,9 +161,12 @@ def create_sft_raw_data(self):
                 self.decode_json_file(
                     data_file_list=train_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["train_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
                     is_multiple_turn=data_info["is_multiple_turn"],
                 )
             )
@@ -154,9 +179,12 @@
                 self.decode_json_file(
                     data_file_list=dev_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["dev_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
+                    output_name=data_info["output_name"],
                     is_multiple_turn=data_info["is_multiple_turn"],
                 )
             )
diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py
new file mode 100644
index 0000000..b84744a
--- /dev/null
+++ b/dbgpt_hub/eval/evaluation_bird.py
@@ -0,0 +1,206 @@
+"""
+Evaluate predicted SQL for the BIRD dataset: execute each predicted query and the
+corresponding ground-truth query from dev.sql, then compare their result sets.
+"""
+
+import sys
+import json
+import argparse
+import sqlite3
+import multiprocessing as mp
+from func_timeout import func_timeout, FunctionTimedOut
+
+
+def load_json(dir):
+    with open(dir, "r") as j:
+        contents = json.loads(j.read())
+    return contents
+
+
+def result_callback(result):
+    exec_result.append(result)
+
+
+def execute_sql(predicted_sql, ground_truth, db_path):
+    # Connect to the database
+    conn = sqlite3.connect(db_path)
+    cursor = conn.cursor()
+    cursor.execute(predicted_sql)
+    predicted_res = cursor.fetchall()
+    cursor.execute(ground_truth)
+    ground_truth_res = cursor.fetchall()
+    res = 0
+    if set(predicted_res) == set(ground_truth_res):
+        res = 1
+    return res
+
+
+def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
+    try:
+        res = func_timeout(
+            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
+        )
+    except KeyboardInterrupt:
+        sys.exit(0)
+    except FunctionTimedOut:
+        result = [(f"timeout",)]
+        res = 0
+    except Exception as e:
+        result = [(f"error",)]  # possibly len(query) > 512 or not executable
+        res = 0
+    # print(result)
+    # result = str(set([ret[0] for ret in result]))
+    result = {"sql_idx": idx, "res": res}
+    # print(result)
+    return result
+
+
+def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
+    clean_sqls = []
+    db_path_list = []
+    if mode == "gpt":
+        # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
+        # for idx, sql_str in sql_data.items():
+        #     if type(sql_str) == str:
+        #         sql, db_name = sql_str.split('\t----- bird -----\t')
+        #     else:
+        #         sql, db_name = " ", "financial"
+        #     clean_sqls.append(sql)
+        #     db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
+        with open(sql_path) as f:
+            for l in f.readlines():
+                # if len(l.strip()) == 0:
+                #     sql, db_name = " ", "financial"
+                # else:
+                #     sql, db_name = l.split('\t')
+                clean_sqls.append(l.strip())
+                # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
+    elif mode == "gt":
+        sqls = open(sql_path)
+        sql_txt = sqls.readlines()
+        # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
+        for idx, sql_str in enumerate(sql_txt):
+            sql, db_name = sql_str.strip().split("\t")
+            clean_sqls.append(sql)
+            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
+
+    return clean_sqls, db_path_list
+
+
+def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
+    pool = mp.Pool(processes=num_cpus)
+    for i, sql_pair in enumerate(sqls):
+        predicted_sql, ground_truth = sql_pair
+        pool.apply_async(
+            execute_model,
+            args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
+            callback=result_callback,
+        )
+    pool.close()
+    pool.join()
+
+
+def sort_results(list_of_dicts):
+    return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
+
+
+def compute_acc_by_diff(exec_results, diff_json_path):
+    num_queries = len(exec_results)
+    results = [res["res"] for res in exec_results]
+    contents = load_json(diff_json_path)
+    simple_results, moderate_results, challenging_results = [], [], []
+
+    for i, content in enumerate(contents):
+        if content["difficulty"] == "simple":
+            simple_results.append(exec_results[i])
+
+        if content["difficulty"] == "moderate":
+            moderate_results.append(exec_results[i])
+
+        if content["difficulty"] == "challenging":
+            challenging_results.append(exec_results[i])
+
+    simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
+    moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
+    challenging_acc = sum([res["res"] for res in challenging_results]) / len(
+        challenging_results
+    )
+    all_acc = sum(results) / num_queries
+    count_lists = [
+        len(simple_results),
+        len(moderate_results),
+        len(challenging_results),
+        num_queries,
+    ]
+    return (
+        simple_acc * 100,
+        moderate_acc * 100,
+        challenging_acc * 100,
+        all_acc * 100,
+        count_lists,
+    )
+
+
+def print_data(score_lists, count_lists):
+    levels = ["simple", "moderate", "challenging", "total"]
+    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
+    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
+
+    print(
+        "====================================== ACCURACY ====================================="
+    )
+    print(
+        "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
+    )
+
+
+if __name__ == "__main__":
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument(
+        "--predicted_sql_path", type=str, required=True, default=""
+    )
+    args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
+    args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
+    args_parser.add_argument("--db_root_path", type=str, required=True, default="")
+    args_parser.add_argument("--num_cpus", type=int, default=1)
+    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
+    args_parser.add_argument("--mode_gt", type=str, default="gt")
+    args_parser.add_argument("--mode_predict", type=str, default="gpt")
+    args_parser.add_argument("--difficulty", type=str, default="simple")
+    args_parser.add_argument("--diff_json_path", type=str, default="")
+    args = args_parser.parse_args()
+    exec_result = []
+
+    pred_queries, db_paths = package_sqls(
+        args.predicted_sql_path,
+        args.db_root_path,
+        mode=args.mode_predict,
+        data_mode=args.data_mode,
+    )
+    # generate gt sqls:
+    gt_queries, db_paths_gt = package_sqls(
+        args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
+    )
+
+    if len(db_paths) == 0:
+        db_paths = db_paths_gt
+
+    query_pairs = list(zip(pred_queries, gt_queries))
+    run_sqls_parallel(
+        query_pairs,
+        db_places=db_paths,
+        num_cpus=args.num_cpus,
+        meta_time_out=args.meta_time_out,
+    )
+    exec_result = sort_results(exec_result)
+
+    print("start calculating")
+    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
+        exec_result, args.diff_json_path
+    )
+    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+    print_data(score_lists, count_lists)
+    print(
+        "==========================================================================================="
+    )
+    print("Finished evaluation")
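
---

Note on the composite primary-key branch added to decode_json_file above: when a
primary key in tables.json is a list of 1-based column indices rather than a
single index, the new elif branch renders one sentence into the schema prompt.
A minimal standalone sketch of that rendering (the column names below are
hypothetical, not taken from any dataset):

    # Sketch of the composite-key sentence built by the new elif branch.
    # `columns` mirrors the tables.json layout: (table_index, column_name) pairs.
    columns = [(0, "order_id"), (0, "product_id"), (0, "quantity")]
    primary_key = [[1, 2]]  # a combination key, as a list of 1-based indices

    source = ""
    for pk in primary_key:
        if isinstance(pk, list):  # combination primary key
            keys = [columns[idx - 1][1] for idx in pk if columns[idx - 1][0] == 0]
            source += "The combination of (" + ", ".join(keys) + ") are the primary key.\n"
    print(source)
    # -> The combination of (order_id, product_id) are the primary key.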
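The new evaluation_bird.py is a standalone script. Based on its argparse setup,
an invocation would look roughly like the following; every path here is a
hypothetical placeholder. package_sqls expects each line of the ground-truth
file to be "SQL<TAB>db_name", and db_root_path is concatenated directly with
the database name, so it should end with a slash:

    python dbgpt_hub/eval/evaluation_bird.py \
        --predicted_sql_path ./output/pred/predict_dev.sql \
        --ground_truth_path ./data/bird/dev/dev.sql \
        --data_mode dev \
        --db_root_path ./data/bird/dev/dev_databases/ \
        --num_cpus 4 \
        --meta_time_out 30.0 \
        --diff_json_path ./data/bird/dev/dev.json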