From a60d828b4fec48db41f6e5b82300ec59158e6210 Mon Sep 17 00:00:00 2001 From: moutozf <3150102181@zju.edu.cn> Date: Wed, 20 Dec 2023 20:51:28 +0800 Subject: [PATCH 1/4] BIRD dataset training, prediction and evaluation scripts --- dbgpt_hub/configs/config.py | 21 ++- dbgpt_hub/data_process/sql_data_process.py | 36 +++-- dbgpt_hub/eval/evaluation_bird.py | 170 +++++++++++++++++++++ 3 files changed, 213 insertions(+), 14 deletions(-) create mode 100644 dbgpt_hub/eval/evaluation_bird.py diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py index 04341a4..bc18eea 100644 --- a/dbgpt_hub/configs/config.py +++ b/dbgpt_hub/configs/config.py @@ -47,13 +47,24 @@ # TODO: BIRD \ WiKiSQL \ ... SQL_DATA_INFO = [ { - "data_source": "spider", - "train_file": ["train_spider.json", "train_others.json"], - "dev_file": ["dev.json"], - "tables_file": "tables.json", + "data_source": "bird", + "train_file": ["train/train.json"], + "dev_file": ["dev/dev.json"], + "train_tables_file": "train/train_tables.json", + "dev_tables_file": "dev/dev_tables.json", "db_id_name": "db_id", + "output_name": "SQL", "is_multiple_turn": False, - } + }, + # { + # "data_source": "spider", + # "train_file": ["train_spider.json", "train_others.json"], + # "dev_file": ["dev.json"], + # "tables_file": "tables.json", + # "db_id_name": "db_id", + # "output_name": "query", + # "is_multiple_turn": False, + # } # , # { # "data_source": "chase", diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py index a3e30c7..7abe597 100644 --- a/dbgpt_hub/data_process/sql_data_process.py +++ b/dbgpt_hub/data_process/sql_data_process.py @@ -22,7 +22,7 @@ def __init__(self, train_file=None, dev_file=None) -> None: self.dev_file = dev_file def decode_json_file( - self, data_file_list, table_file, db_id_name, is_multiple_turn=False + self, data_file_list, table_file, db_id_name, output_name, is_multiple_turn=False ): """ TO DO: @@ -63,12 +63,28 @@ def decode_json_file( # get primary key info for j in range(len(primary_key)): - if coloumns[primary_key[j] - 1][0] == i: + if type(primary_key[j]) == int: + if coloumns[primary_key[j] - 1][0] == i: + source += ( + coloumns[primary_key[j] - 1][1] + + " is the primary key." + + "\n" + ) + elif type(primary_key[j]) == list: + combine_p = "The combination of (" + keys = [] + for k in range(len(primary_key[j])): + if coloumns[primary_key[j][k] - 1][0] == i: + keys.append(coloumns[primary_key[j][k] - 1][1]) source += ( - coloumns[primary_key[j] - 1][1] - + " is the primary key." - + "\n" + combine_p + + ", ".join(keys) + + ") are the primary key." 
+ + "\n" ) + else: + print("not support type", type(primary_key[j])) + continue # get foreign key info for key in foreign_keys: @@ -98,7 +114,7 @@ def decode_json_file( db_dict[data[db_id_name]] ), "input": INPUT_PROMPT.format(interaction["utterance"]), - "output": interaction["query"], + "output": interaction[output_name], "history": history, } res.append(input) @@ -115,7 +131,7 @@ def decode_json_file( db_dict[data[db_id_name]] ), "input": INPUT_PROMPT.format(data["question"]), - "output": data["query"], + "output": data[output_name], "history": [], } res.append(input) @@ -133,9 +149,10 @@ def create_sft_raw_data(self): self.decode_json_file( data_file_list=train_data_file_list, table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["tables_file"] + DATA_PATH, data_info["data_source"], data_info["train_tables_file"] ), db_id_name=data_info["db_id_name"], + output_name=data_info["output_name"], is_multiple_turn=data_info["is_multiple_turn"], ) ) @@ -148,9 +165,10 @@ def create_sft_raw_data(self): self.decode_json_file( data_file_list=dev_data_file_list, table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["tables_file"] + DATA_PATH, data_info["data_source"], data_info["dev_tables_file"] ), db_id_name=data_info["db_id_name"], + output_name=data_info["output_name"], is_multiple_turn=data_info["is_multiple_turn"], ) ) diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py new file mode 100644 index 0000000..0c9b4ba --- /dev/null +++ b/dbgpt_hub/eval/evaluation_bird.py @@ -0,0 +1,170 @@ +""" +do evaluate about the predict sql in dataset BIRD,compare with default dev.sql +--db +""" + +import sys +import json +import argparse +import sqlite3 +import multiprocessing as mp +from func_timeout import func_timeout, FunctionTimedOut + + +def load_json(dir): + with open(dir, 'r') as j: + contents = json.loads(j.read()) + return contents + + +def result_callback(result): + exec_result.append(result) + + +def execute_sql(predicted_sql, ground_truth, db_path): + conn = sqlite3.connect(db_path) + # Connect to the database + cursor = conn.cursor() + cursor.execute(predicted_sql) + predicted_res = cursor.fetchall() + cursor.execute(ground_truth) + ground_truth_res = cursor.fetchall() + res = 0 + if set(predicted_res) == set(ground_truth_res): + res = 1 + return res + + +def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out): + try: + res = func_timeout(meta_time_out, execute_sql, + args=(predicted_sql, ground_truth, db_place)) + except KeyboardInterrupt: + sys.exit(0) + except FunctionTimedOut: + result = [(f'timeout',)] + res = 0 + except Exception as e: + result = [(f'error',)] # possibly len(query) > 512 or not executable + res = 0 + # print(result) + # result = str(set([ret[0] for ret in result])) + result = {'sql_idx': idx, 'res': res} + # print(result) + return result + + +def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): + clean_sqls = [] + db_path_list = [] + if mode == 'gpt': + # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r')) + # for idx, sql_str in sql_data.items(): + # if type(sql_str) == str: + # sql, db_name = sql_str.split('\t----- bird -----\t') + # else: + # sql, db_name = " ", "financial" + # clean_sqls.append(sql) + # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') + with open(sql_path) as f: + for l in f.readlines(): + # if len(l.strip()) == 0: + # sql, db_name = " ", "financial" + # else: + # sql, db_name = l.split('\t') + 
clean_sqls.append(l.strip()) + # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') + elif mode == 'gt': + sqls = open(sql_path) + sql_txt = sqls.readlines() + # sql_txt = [sql.split('\t')[0] for sql in sql_txt] + for idx, sql_str in enumerate(sql_txt): + sql, db_name = sql_str.strip().split('\t') + clean_sqls.append(sql) + db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') + + return clean_sqls, db_path_list + + +def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0): + pool = mp.Pool(processes=num_cpus) + for i, sql_pair in enumerate(sqls): + predicted_sql, ground_truth = sql_pair + pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), + callback=result_callback) + pool.close() + pool.join() + + +def sort_results(list_of_dicts): + return sorted(list_of_dicts, key=lambda x: x['sql_idx']) + + +def compute_acc_by_diff(exec_results, diff_json_path): + num_queries = len(exec_results) + results = [res['res'] for res in exec_results] + contents = load_json(diff_json_path) + simple_results, moderate_results, challenging_results = [], [], [] + + for i, content in enumerate(contents): + if content['difficulty'] == 'simple': + simple_results.append(exec_results[i]) + + if content['difficulty'] == 'moderate': + moderate_results.append(exec_results[i]) + + if content['difficulty'] == 'challenging': + challenging_results.append(exec_results[i]) + + simple_acc = sum([res['res'] for res in simple_results]) / len(simple_results) + moderate_acc = sum([res['res'] for res in moderate_results]) / len(moderate_results) + challenging_acc = sum([res['res'] for res in challenging_results]) / len(challenging_results) + all_acc = sum(results) / num_queries + count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries] + return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists + + +def print_data(score_lists, count_lists): + levels = ['simple', 'moderate', 'challenging', 'total'] + print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) + print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists)) + + print('====================================== ACCURACY =====================================') + print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists)) + + +if __name__ == '__main__': + args_parser = argparse.ArgumentParser() + args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='') + args_parser.add_argument('--ground_truth_path', type=str, required=True, default='') + args_parser.add_argument('--data_mode', type=str, required=True, default='dev') + args_parser.add_argument('--db_root_path', type=str, required=True, default='') + args_parser.add_argument('--num_cpus', type=int, default=1) + args_parser.add_argument('--meta_time_out', type=float, default=30.0) + args_parser.add_argument('--mode_gt', type=str, default='gt') + args_parser.add_argument('--mode_predict', type=str, default='gpt') + args_parser.add_argument('--difficulty', type=str, default='simple') + args_parser.add_argument('--diff_json_path', type=str, default='') + args = args_parser.parse_args() + exec_result = [] + + pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict, + data_mode=args.data_mode) + # generate gt sqls: + gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt', + data_mode=args.data_mode) + + if 
len(db_paths) == 0: + db_paths = db_paths_gt + + query_pairs = list(zip(pred_queries, gt_queries)) + run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out) + exec_result = sort_results(exec_result) + + print('start calculate') + simple_acc, moderate_acc, challenging_acc, acc, count_lists = \ + compute_acc_by_diff(exec_result, args.diff_json_path) + score_lists = [simple_acc, moderate_acc, challenging_acc, acc] + print_data(score_lists, count_lists) + print('===========================================================================================') + print("Finished evaluation") From 8ce58b31e05d09fec7d8c5617196db7512f00c49 Mon Sep 17 00:00:00 2001 From: moutozf Date: Fri, 29 Dec 2023 17:08:05 +0800 Subject: [PATCH 2/4] add spider config and comment --- dbgpt_hub/configs/config.py | 5 +++-- dbgpt_hub/data_process/sql_data_process.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py index bc18eea..0c485f2 100644 --- a/dbgpt_hub/configs/config.py +++ b/dbgpt_hub/configs/config.py @@ -55,12 +55,13 @@ "db_id_name": "db_id", "output_name": "SQL", "is_multiple_turn": False, - }, + } # { # "data_source": "spider", # "train_file": ["train_spider.json", "train_others.json"], # "dev_file": ["dev.json"], - # "tables_file": "tables.json", + # "train_tables_file": "tables.json", + # "dev_tables_file": "tables.json", # "db_id_name": "db_id", # "output_name": "query", # "is_multiple_turn": False, diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py index 7abe597..dc5da35 100644 --- a/dbgpt_hub/data_process/sql_data_process.py +++ b/dbgpt_hub/data_process/sql_data_process.py @@ -70,6 +70,7 @@ def decode_json_file( + " is the primary key." + "\n" ) + # combination primary key elif type(primary_key[j]) == list: combine_p = "The combination of (" keys = [] From 44d43fde75b4931edd3c4a4569488f1bef827616 Mon Sep 17 00:00:00 2001 From: qidanrui Date: Fri, 29 Dec 2023 11:09:33 +0000 Subject: [PATCH 3/4] reformat code style --- dbgpt_hub/data_process/sql_data_process.py | 29 +++-- dbgpt_hub/eval/evaluation_bird.py | 130 +++++++++++++-------- 2 files changed, 102 insertions(+), 57 deletions(-) diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py index c2864aa..db2899b 100644 --- a/dbgpt_hub/data_process/sql_data_process.py +++ b/dbgpt_hub/data_process/sql_data_process.py @@ -24,7 +24,12 @@ def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None: self.num_shot = num_shot def decode_json_file( - self, data_file_list, table_file, db_id_name, output_name, is_multiple_turn=False + self, + data_file_list, + table_file, + db_id_name, + output_name, + is_multiple_turn=False, ): """ TO DO: @@ -68,9 +73,9 @@ def decode_json_file( if type(primary_key[j]) == int: if coloumns[primary_key[j] - 1][0] == i: source += ( - coloumns[primary_key[j] - 1][1] - + " is the primary key." - + "\n" + coloumns[primary_key[j] - 1][1] + + " is the primary key." + + "\n" ) # combination primary key elif type(primary_key[j]) == list: @@ -80,10 +85,10 @@ def decode_json_file( if coloumns[primary_key[j][k] - 1][0] == i: keys.append(coloumns[primary_key[j][k] - 1][1]) source += ( - combine_p + - ", ".join(keys) - + ") are the primary key." - + "\n" + combine_p + + ", ".join(keys) + + ") are the primary key." 
+ + "\n" ) else: print("not support type", type(primary_key[j])) @@ -156,7 +161,9 @@ def create_sft_raw_data(self): self.decode_json_file( data_file_list=train_data_file_list, table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["train_tables_file"] + DATA_PATH, + data_info["data_source"], + data_info["train_tables_file"], ), db_id_name=data_info["db_id_name"], output_name=data_info["output_name"], @@ -172,7 +179,9 @@ def create_sft_raw_data(self): self.decode_json_file( data_file_list=dev_data_file_list, table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["dev_tables_file"] + DATA_PATH, + data_info["data_source"], + data_info["dev_tables_file"], ), db_id_name=data_info["db_id_name"], output_name=data_info["output_name"], diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py index 0c9b4ba..b84744a 100644 --- a/dbgpt_hub/eval/evaluation_bird.py +++ b/dbgpt_hub/eval/evaluation_bird.py @@ -12,7 +12,7 @@ def load_json(dir): - with open(dir, 'r') as j: + with open(dir, "r") as j: contents = json.loads(j.read()) return contents @@ -37,27 +37,28 @@ def execute_sql(predicted_sql, ground_truth, db_path): def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out): try: - res = func_timeout(meta_time_out, execute_sql, - args=(predicted_sql, ground_truth, db_place)) + res = func_timeout( + meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place) + ) except KeyboardInterrupt: sys.exit(0) except FunctionTimedOut: - result = [(f'timeout',)] + result = [(f"timeout",)] res = 0 except Exception as e: - result = [(f'error',)] # possibly len(query) > 512 or not executable + result = [(f"error",)] # possibly len(query) > 512 or not executable res = 0 # print(result) # result = str(set([ret[0] for ret in result])) - result = {'sql_idx': idx, 'res': res} + result = {"sql_idx": idx, "res": res} # print(result) return result -def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): +def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"): clean_sqls = [] db_path_list = [] - if mode == 'gpt': + if mode == "gpt": # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r')) # for idx, sql_str in sql_data.items(): # if type(sql_str) == str: @@ -74,14 +75,14 @@ def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'): # sql, db_name = l.split('\t') clean_sqls.append(l.strip()) # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') - elif mode == 'gt': + elif mode == "gt": sqls = open(sql_path) sql_txt = sqls.readlines() # sql_txt = [sql.split('\t')[0] for sql in sql_txt] for idx, sql_str in enumerate(sql_txt): - sql, db_name = sql_str.strip().split('\t') + sql, db_name = sql_str.strip().split("\t") clean_sqls.append(sql) - db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite') + db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") return clean_sqls, db_path_list @@ -90,81 +91,116 @@ def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0): pool = mp.Pool(processes=num_cpus) for i, sql_pair in enumerate(sqls): predicted_sql, ground_truth = sql_pair - pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), - callback=result_callback) + pool.apply_async( + execute_model, + args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), + callback=result_callback, + ) pool.close() pool.join() def sort_results(list_of_dicts): - return 
sorted(list_of_dicts, key=lambda x: x['sql_idx']) + return sorted(list_of_dicts, key=lambda x: x["sql_idx"]) def compute_acc_by_diff(exec_results, diff_json_path): num_queries = len(exec_results) - results = [res['res'] for res in exec_results] + results = [res["res"] for res in exec_results] contents = load_json(diff_json_path) simple_results, moderate_results, challenging_results = [], [], [] for i, content in enumerate(contents): - if content['difficulty'] == 'simple': + if content["difficulty"] == "simple": simple_results.append(exec_results[i]) - if content['difficulty'] == 'moderate': + if content["difficulty"] == "moderate": moderate_results.append(exec_results[i]) - if content['difficulty'] == 'challenging': + if content["difficulty"] == "challenging": challenging_results.append(exec_results[i]) - simple_acc = sum([res['res'] for res in simple_results]) / len(simple_results) - moderate_acc = sum([res['res'] for res in moderate_results]) / len(moderate_results) - challenging_acc = sum([res['res'] for res in challenging_results]) / len(challenging_results) + simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) + moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) + challenging_acc = sum([res["res"] for res in challenging_results]) / len( + challenging_results + ) all_acc = sum(results) / num_queries - count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries] - return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists + count_lists = [ + len(simple_results), + len(moderate_results), + len(challenging_results), + num_queries, + ] + return ( + simple_acc * 100, + moderate_acc * 100, + challenging_acc * 100, + all_acc * 100, + count_lists, + ) def print_data(score_lists, count_lists): - levels = ['simple', 'moderate', 'challenging', 'total'] + levels = ["simple", "moderate", "challenging", "total"] print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) - print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists)) + print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) - print('====================================== ACCURACY =====================================') - print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists)) + print( + "====================================== ACCURACY =====================================" + ) + print( + "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists) + ) -if __name__ == '__main__': +if __name__ == "__main__": args_parser = argparse.ArgumentParser() - args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='') - args_parser.add_argument('--ground_truth_path', type=str, required=True, default='') - args_parser.add_argument('--data_mode', type=str, required=True, default='dev') - args_parser.add_argument('--db_root_path', type=str, required=True, default='') - args_parser.add_argument('--num_cpus', type=int, default=1) - args_parser.add_argument('--meta_time_out', type=float, default=30.0) - args_parser.add_argument('--mode_gt', type=str, default='gt') - args_parser.add_argument('--mode_predict', type=str, default='gpt') - args_parser.add_argument('--difficulty', type=str, default='simple') - args_parser.add_argument('--diff_json_path', type=str, default='') + args_parser.add_argument( + "--predicted_sql_path", type=str, required=True, default="" + ) + args_parser.add_argument("--ground_truth_path", 
type=str, required=True, default="") + args_parser.add_argument("--data_mode", type=str, required=True, default="dev") + args_parser.add_argument("--db_root_path", type=str, required=True, default="") + args_parser.add_argument("--num_cpus", type=int, default=1) + args_parser.add_argument("--meta_time_out", type=float, default=30.0) + args_parser.add_argument("--mode_gt", type=str, default="gt") + args_parser.add_argument("--mode_predict", type=str, default="gpt") + args_parser.add_argument("--difficulty", type=str, default="simple") + args_parser.add_argument("--diff_json_path", type=str, default="") args = args_parser.parse_args() exec_result = [] - pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict, - data_mode=args.data_mode) + pred_queries, db_paths = package_sqls( + args.predicted_sql_path, + args.db_root_path, + mode=args.mode_predict, + data_mode=args.data_mode, + ) # generate gt sqls: - gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt', - data_mode=args.data_mode) + gt_queries, db_paths_gt = package_sqls( + args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode + ) if len(db_paths) == 0: db_paths = db_paths_gt query_pairs = list(zip(pred_queries, gt_queries)) - run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out) + run_sqls_parallel( + query_pairs, + db_places=db_paths, + num_cpus=args.num_cpus, + meta_time_out=args.meta_time_out, + ) exec_result = sort_results(exec_result) - print('start calculate') - simple_acc, moderate_acc, challenging_acc, acc, count_lists = \ - compute_acc_by_diff(exec_result, args.diff_json_path) + print("start calculate") + simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff( + exec_result, args.diff_json_path + ) score_lists = [simple_acc, moderate_acc, challenging_acc, acc] print_data(score_lists, count_lists) - print('===========================================================================================') + print( + "===========================================================================================" + ) print("Finished evaluation") From bab9c3b28f146f463e6a601357aaa66894549792 Mon Sep 17 00:00:00 2001 From: moutozf Date: Mon, 1 Jan 2024 22:57:29 +0800 Subject: [PATCH 4/4] change dataset to spider and fix output name in multi turn --- dbgpt_hub/configs/config.py | 24 +++++++++++----------- dbgpt_hub/data_process/sql_data_process.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py index f316410..268d6d6 100644 --- a/dbgpt_hub/configs/config.py +++ b/dbgpt_hub/configs/config.py @@ -47,23 +47,23 @@ # TODO: BIRD \ WiKiSQL \ ... 
SQL_DATA_INFO = [ { - "data_source": "bird", - "train_file": ["train/train.json"], - "dev_file": ["dev/dev.json"], - "train_tables_file": "train/train_tables.json", - "dev_tables_file": "dev/dev_tables.json", + "data_source": "spider", + "train_file": ["train_spider.json", "train_others.json"], + "dev_file": ["dev.json"], + "train_tables_file": "tables.json", + "dev_tables_file": "tables.json", "db_id_name": "db_id", - "output_name": "SQL", + "output_name": "query", "is_multiple_turn": False, } # { - # "data_source": "spider", - # "train_file": ["train_spider.json", "train_others.json"], - # "dev_file": ["dev.json"], - # "train_tables_file": "tables.json", - # "dev_tables_file": "tables.json", + # "data_source": "bird", + # "train_file": ["train/train.json"], + # "dev_file": ["dev/dev.json"], + # "train_tables_file": "train/train_tables.json", + # "dev_tables_file": "dev/dev_tables.json", # "db_id_name": "db_id", - # "output_name": "query", + # "output_name": "SQL", # "is_multiple_turn": False, # } # , diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py index db2899b..33e1883 100644 --- a/dbgpt_hub/data_process/sql_data_process.py +++ b/dbgpt_hub/data_process/sql_data_process.py @@ -133,7 +133,7 @@ def decode_json_file( history.append( ( INPUT_PROMPT.format(interaction["utterance"]), - interaction["query"], + interaction[output_name], ) ) else: # 单轮
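
Note on the input format consumed by the new dbgpt_hub/eval/evaluation_bird.py:
package_sqls() reads the prediction file as one bare SQL string per line (mode
"gpt"), while the ground-truth file pairs each SQL with its database id as
"SQL<TAB>db_id" (mode "gt"); the ground-truth side is also what supplies the
.sqlite paths when the prediction side provides none. The sketch below is a
minimal illustration of that contract, assuming the module is importable as
dbgpt_hub.eval.evaluation_bird (which also needs the func_timeout package it
imports) and using made-up paths, database and table names:

    import os
    import tempfile

    from dbgpt_hub.eval.evaluation_bird import package_sqls

    # Build two tiny input files in the formats the script expects.
    tmp_dir = tempfile.mkdtemp()
    pred_path = os.path.join(tmp_dir, "pred_dev.sql")  # one SQL per line
    gold_path = os.path.join(tmp_dir, "dev_gold.sql")  # SQL<TAB>db_id per line
    with open(pred_path, "w") as f:
        f.write("SELECT count(*) FROM account\n")
    with open(gold_path, "w") as f:
        f.write("SELECT count(*) FROM account\tfinancial\n")

    db_root = "./data/bird/dev/dev_databases/"  # illustrative BIRD layout
    pred_sqls, _ = package_sqls(pred_path, db_root, mode="gpt", data_mode="dev")
    gold_sqls, db_paths = package_sqls(gold_path, db_root, mode="gt", data_mode="dev")

    print(pred_sqls)  # ['SELECT count(*) FROM account']
    print(gold_sqls)  # ['SELECT count(*) FROM account']
    print(db_paths)   # ['./data/bird/dev/dev_databases/financial/financial.sqlite']

    # End to end, the script is run from the command line, for example
    # (all paths are placeholders):
    #   python dbgpt_hub/eval/evaluation_bird.py \
    #       --predicted_sql_path pred_dev.sql --ground_truth_path dev_gold.sql \
    #       --db_root_path ./data/bird/dev/dev_databases/ --data_mode dev \
    #       --diff_json_path ./data/bird/dev/dev.json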
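
The per-difficulty report from compute_acc_by_diff() assumes the i-th executed
query lines up with the i-th entry of the file passed as --diff_json_path
(BIRD's dev.json, whose entries carry a "difficulty" field with the values
simple / moderate / challenging). A toy, index-aligned illustration with
abbreviated, made-up entries:

    # Real BIRD dev.json entries contain more fields than shown here.
    exec_results = [
        {"sql_idx": 0, "res": 1},  # predicted SQL 0 matched the gold result set
        {"sql_idx": 1, "res": 0},  # predicted SQL 1 did not
        {"sql_idx": 2, "res": 1},
    ]
    contents = [
        {"difficulty": "simple"},
        {"difficulty": "simple"},
        {"difficulty": "moderate"},
    ]

    simple = [r["res"] for r, c in zip(exec_results, contents) if c["difficulty"] == "simple"]
    moderate = [r["res"] for r, c in zip(exec_results, contents) if c["difficulty"] == "moderate"]

    print(sum(simple) / len(simple) * 100)      # 50.0  -> execution accuracy, "simple"
    print(sum(moderate) / len(moderate) * 100)  # 100.0 -> execution accuracy, "moderate"

Note that each bucket's accuracy divides by the bucket size, so evaluating a
subset that leaves one of the three difficulty levels empty would raise a
ZeroDivisionError in compute_acc_by_diff().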