diff --git a/assets/wechat.JPG b/assets/wechat.JPG
index 3cde49e..af45cab 100644
Binary files a/assets/wechat.JPG and b/assets/wechat.JPG differ
diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index 0c485f2..268d6d6 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -47,23 +47,23 @@
 # TODO: BIRD \ WiKiSQL \ ...
 SQL_DATA_INFO = [
     {
-        "data_source": "bird",
-        "train_file": ["train/train.json"],
-        "dev_file": ["dev/dev.json"],
-        "train_tables_file": "train/train_tables.json",
-        "dev_tables_file": "dev/dev_tables.json",
+        "data_source": "spider",
+        "train_file": ["train_spider.json", "train_others.json"],
+        "dev_file": ["dev.json"],
+        "train_tables_file": "tables.json",
+        "dev_tables_file": "tables.json",
         "db_id_name": "db_id",
-        "output_name": "SQL",
+        "output_name": "query",
         "is_multiple_turn": False,
     }
     # {
-    #     "data_source": "spider",
-    #     "train_file": ["train_spider.json", "train_others.json"],
-    #     "dev_file": ["dev.json"],
-    #     "train_tables_file": "tables.json",
-    #     "dev_tables_file": "tables.json",
+    #     "data_source": "bird",
+    #     "train_file": ["train/train.json"],
+    #     "dev_file": ["dev/dev.json"],
+    #     "train_tables_file": "train/train_tables.json",
+    #     "dev_tables_file": "dev/dev_tables.json",
     #     "db_id_name": "db_id",
-    #     "output_name": "query",
+    #     "output_name": "SQL",
     #     "is_multiple_turn": False,
     # }
     # ,
@@ -101,6 +101,31 @@
 ##Instruction:\n{}\n"""
 INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
 
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You need only to return the sql command to me. \
+First, I will show you few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n###New Instruction:\n{}\n"""
+
+# EXAMPLES =[EXAMPLE1, EXAMPLE1]
+
+# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+# \n###New Instruction:\n{}\n"
+
+### test--------------------
+
+
 
 # METHODS = ["full", "freeze", "lora"]
 # STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
diff --git a/dbgpt_hub/data/dataset_info.json b/dbgpt_hub/data/dataset_info.json
index f5988a1..49d14c8 100644
--- a/dbgpt_hub/data/dataset_info.json
+++ b/dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
       "history": "history"
     }
   },
+  "example_text2sql_train_one_shot": {
+    "file_name": "example_text2sql_train_one_shot.json",
+    "columns": {
+      "prompt": "instruction",
+      "query": "input",
+      "response": "output",
+      "history": "history"
+    }
+  },
   "example_rm_train": {
     "file_name": "oaast_rm_zh.json",
     "columns": {
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index dc5da35..18399d5 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -13,16 +13,23 @@
     DATA_PATH,
     INPUT_PROMPT,
     INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
 )
 
 
 class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.train_file = train_file
         self.dev_file = dev_file
+        self.num_shot = num_shot
 
     def decode_json_file(
-        self, data_file_list, table_file, db_id_name, output_name, is_multiple_turn=False
+        self,
+        data_file_list,
+        table_file,
+        db_id_name,
+        output_name,
+        is_multiple_turn=False,
     ):
         """
         TO DO:
@@ -66,9 +73,9 @@ def decode_json_file(
                     if type(primary_key[j]) == int:
                         if coloumns[primary_key[j] - 1][0] == i:
                             source += (
-                                    coloumns[primary_key[j] - 1][1]
-                                    + " is the primary key."
-                                    + "\n"
+                                coloumns[primary_key[j] - 1][1]
+                                + " is the primary key."
+                                + "\n"
                             )
                     # combination primary key
                     elif type(primary_key[j]) == list:
@@ -78,10 +85,10 @@ def decode_json_file(
                             if coloumns[primary_key[j][k] - 1][0] == i:
                                 keys.append(coloumns[primary_key[j][k] - 1][1])
                         source += (
-                                combine_p +
-                                ", ".join(keys)
-                                + ") are the primary key."
-                                + "\n"
+                            combine_p
+                            + ", ".join(keys)
+                            + ") are the primary key."
+                            + "\n"
                         )
                     else:
                         print("not support type", type(primary_key[j]))
@@ -104,6 +111,10 @@ def decode_json_file(
             db_dict[item["db_id"]] = source
 
         res = []
+        base_instruction = INSTRUCTION_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
         for data in tqdm(datas):
             if data[db_id_name] in db_dict.keys():
                 if is_multiple_turn:  # 多轮
@@ -111,7 +122,7 @@ def decode_json_file(
                     for interaction in data["interaction"]:
                         input = {
                             "db_id": data[db_id_name],
-                            "instruction": INSTRUCTION_PROMPT.format(
+                            "instruction": base_instruction.format(
                                 db_dict[data[db_id_name]]
                             ),
                             "input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -128,7 +139,7 @@ def decode_json_file(
                 else:  # 单轮
                     input = {
                         "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                             db_dict[data[db_id_name]]
                         ),
                         "input": INPUT_PROMPT.format(data["question"]),
@@ -150,7 +161,9 @@ def create_sft_raw_data(self):
                 self.decode_json_file(
                     data_file_list=train_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["train_tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["train_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
                     output_name=data_info["output_name"],
@@ -166,7 +179,9 @@ def create_sft_raw_data(self):
                 self.decode_json_file(
                     data_file_list=dev_data_file_list,
                     table_file=os.path.join(
-                        DATA_PATH, data_info["data_source"], data_info["dev_tables_file"]
+                        DATA_PATH,
+                        data_info["data_source"],
+                        data_info["dev_tables_file"],
                     ),
                     db_id_name=data_info["db_id_name"],
                     output_name=data_info["output_name"],
@@ -186,3 +201,17 @@ def create_sft_raw_data(self):
         train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
     )
     precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    one_shot_precess = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file,
+        dev_file=one_shot_all_in_one_dev_file,
+        num_shot=1,
+    )
+    one_shot_precess.create_sft_raw_data()
\ No newline at end of file
diff --git a/dbgpt_hub/eval/evaluation_bird.py b/dbgpt_hub/eval/evaluation_bird.py
index 0c9b4ba..649632b 100644
--- a/dbgpt_hub/eval/evaluation_bird.py
+++ b/dbgpt_hub/eval/evaluation_bird.py
@@ -12,7 +12,7 @@
 
 
 def load_json(dir):
-    with open(dir, 'r') as j:
+    with open(dir, "r") as j:
         contents = json.loads(j.read())
     return contents
 
@@ -37,27 +37,28 @@ def execute_sql(predicted_sql, ground_truth, db_path):
 
 def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
     try:
-        res = func_timeout(meta_time_out, execute_sql,
-                           args=(predicted_sql, ground_truth, db_place))
+        res = func_timeout(
+            meta_time_out, execute_sql, args=(predicted_sql, ground_truth, db_place)
+        )
     except KeyboardInterrupt:
         sys.exit(0)
     except FunctionTimedOut:
-        result = [(f'timeout',)]
+        result = [(f"timeout",)]
         res = 0
     except Exception as e:
-        result = [(f'error',)]  # possibly len(query) > 512 or not executable
+        result = [(f"error",)]  # possibly len(query) > 512 or not executable
        res = 0
     # print(result)
     # result = str(set([ret[0] for ret in result]))
-    result = {'sql_idx': idx, 'res': res}
+    result = {"sql_idx": idx, "res": res}
     # print(result)
     return result
 
 
-def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
+def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"):
     clean_sqls = []
     db_path_list = []
-    if mode == 'gpt':
+    if mode == "gpt":
         # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
         # for idx, sql_str in sql_data.items():
         #     if type(sql_str) == str:
@@ -74,14 +75,14 @@ def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
             # sql, db_name = l.split('\t')
             clean_sqls.append(l.strip())
             # db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
-    elif mode == 'gt':
+    elif mode == "gt":
         sqls = open(sql_path)
         sql_txt = sqls.readlines()
         # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
         for idx, sql_str in enumerate(sql_txt):
-            sql, db_name = sql_str.strip().split('\t')
+            sql, db_name = sql_str.strip().split("\t")
             clean_sqls.append(sql)
-            db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
+            db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite")
 
     return clean_sqls, db_path_list
 
@@ -90,81 +91,116 @@ def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
     pool = mp.Pool(processes=num_cpus)
     for i, sql_pair in enumerate(sqls):
         predicted_sql, ground_truth = sql_pair
-        pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
-                         callback=result_callback)
+        pool.apply_async(
+            execute_model,
+            args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out),
+            callback=result_callback,
+        )
     pool.close()
     pool.join()
 
 
 def sort_results(list_of_dicts):
-    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
+    return sorted(list_of_dicts, key=lambda x: x["sql_idx"])
 
 
 def compute_acc_by_diff(exec_results, diff_json_path):
     num_queries = len(exec_results)
-    results = [res['res'] for res in exec_results]
+    results = [res["res"] for res in exec_results]
     contents = load_json(diff_json_path)
     simple_results, moderate_results, challenging_results = [], [], []
 
     for i, content in enumerate(contents):
-        if content['difficulty'] == 'simple':
+        if content["difficulty"] == "simple":
             simple_results.append(exec_results[i])
 
-        if content['difficulty'] == 'moderate':
+        if content["difficulty"] == "moderate":
             moderate_results.append(exec_results[i])
 
-        if content['difficulty'] == 'challenging':
+        if content["difficulty"] == "challenging":
             challenging_results.append(exec_results[i])
 
-    simple_acc = sum([res['res'] for res in simple_results]) / len(simple_results)
-    moderate_acc = sum([res['res'] for res in moderate_results]) / len(moderate_results)
-    challenging_acc = sum([res['res'] for res in challenging_results]) / len(challenging_results)
+    simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results)
+    moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results)
+    challenging_acc = sum([res["res"] for res in challenging_results]) / len(
+        challenging_results
+    )
     all_acc = sum(results) / num_queries
-    count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
-    return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
+    count_lists = [
+        len(simple_results),
+        len(moderate_results),
+        len(challenging_results),
+        num_queries,
+    ]
+    return (
+        simple_acc * 100,
+        moderate_acc * 100,
+        challenging_acc * 100,
+        all_acc * 100,
+        count_lists,
+    )
 
 
 def print_data(score_lists, count_lists):
-    levels = ['simple', 'moderate', 'challenging', 'total']
+    levels = ["simple", "moderate", "challenging", "total"]
     print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
-    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
+    print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists))
 
-    print('====================================== ACCURACY =====================================')
-    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
+    print(
+        "====================================== ACCURACY ====================================="
+    )
+    print(
+        "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists)
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     args_parser = argparse.ArgumentParser()
-    args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='')
-    args_parser.add_argument('--ground_truth_path', type=str, required=True, default='')
-    args_parser.add_argument('--data_mode', type=str, required=True, default='dev')
-    args_parser.add_argument('--db_root_path', type=str, required=True, default='')
-    args_parser.add_argument('--num_cpus', type=int, default=1)
-    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
-    args_parser.add_argument('--mode_gt', type=str, default='gt')
-    args_parser.add_argument('--mode_predict', type=str, default='gpt')
-    args_parser.add_argument('--difficulty', type=str, default='simple')
-    args_parser.add_argument('--diff_json_path', type=str, default='')
+    args_parser.add_argument(
+        "--predicted_sql_path", type=str, required=True, default=""
+    )
+    args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
+    args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
+    args_parser.add_argument("--db_root_path", type=str, required=True, default="")
+    args_parser.add_argument("--num_cpus", type=int, default=1)
+    args_parser.add_argument("--meta_time_out", type=float, default=30.0)
+    args_parser.add_argument("--mode_gt", type=str, default="gt")
+    args_parser.add_argument("--mode_predict", type=str, default="gpt")
+    args_parser.add_argument("--difficulty", type=str, default="simple")
+    args_parser.add_argument("--diff_json_path", type=str, default="")
     args = args_parser.parse_args()
     exec_result = []
 
-    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode=args.mode_predict,
-                                          data_mode=args.data_mode)
+    pred_queries, db_paths = package_sqls(
+        args.predicted_sql_path,
+        args.db_root_path,
+        mode=args.mode_predict,
+        data_mode=args.data_mode,
+    )
     # generate gt sqls:
-    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
-                                           data_mode=args.data_mode)
+    gt_queries, db_paths_gt = package_sqls(
+        args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode
+    )
 
     if len(db_paths) == 0:
         db_paths = db_paths_gt
 
     query_pairs = list(zip(pred_queries, gt_queries))
-    run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
+    run_sqls_parallel(
+        query_pairs,
+        db_places=db_paths,
+        num_cpus=args.num_cpus,
+        meta_time_out=args.meta_time_out,
+    )
    exec_result = sort_results(exec_result)
 
-    print('start calculate')
-    simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
-        compute_acc_by_diff(exec_result, args.diff_json_path)
+    print("start calculate")
+    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
+        exec_result, args.diff_json_path
+    )
     score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
     print_data(score_lists, count_lists)
-    print('===========================================================================================')
-    print("Finished evaluation")
+    print(
+        "==========================================================================================="
+    )
+    print("Finished evaluation")
\ No newline at end of file
diff --git a/dbgpt_hub/scripts/train_sft.sh b/dbgpt_hub/scripts/train_sft.sh
index ae4b276..14eb4c2 100644
--- a/dbgpt_hub/scripts/train_sft.sh
+++ b/dbgpt_hub/scripts/train_sft.sh
@@ -5,11 +5,24 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
 start_time=$(date +%s)
 echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
 
+# # zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
 # the default param set could be run in a server with one a100(40G) gpu, if your server not support the set,you can set smaller param such as lora_rank and use qlora with quant 4 eg...
 CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
     --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
     --max_source_length 2048 \
     --max_target_length 512 \
     --finetuning_type lora \
@@ -17,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
     --template llama2 \
     --lora_rank 64 \
     --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
     --overwrite_cache \
     --overwrite_output_dir \
     --per_device_train_batch_size 1 \