From c066049d02fd186bcfaceb5cb6d3986d0880cf2f Mon Sep 17 00:00:00 2001
From: junewgl <45283002+junewgl@users.noreply.github.com>
Date: Thu, 28 Dec 2023 21:27:16 +0800
Subject: [PATCH] feat: add one-shot train (#197)

---
 dbgpt_hub/configs/config.py                | 16 ++++++++++++++++
 dbgpt_hub/data/dataset_info.json           |  9 +++++++++
 dbgpt_hub/data_process/sql_data_process.py | 26 +++++++++++++++++++++---
 dbgpt_hub/scripts/train_sft.sh             | 19 ++++++++++++++++---
 4 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index 04341a4..f5b2563 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -89,6 +89,22 @@
 ##Instruction:\n{}\n"""
 INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
 
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You need only return the SQL command to me. \
+First, I will show you an example of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n### New Instruction:\n{}\n"""
+
 
 # METHODS = ["full", "freeze", "lora"]
 # STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
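To make the new template concrete: with this patch applied, the one-shot instruction composes with the existing INPUT_PROMPT as sketched below. This is a minimal illustration, not part of the patch; the schema string and question are invented, and it assumes the training pipeline concatenates the instruction and input columns (as the prompt/query mapping in dataset_info.json suggests).

    # Minimal sketch: format the one-shot template the same way
    # decode_json_file() does further down in this patch. The schema text is
    # a made-up stand-in for the description generated from the Spider-style
    # tables file.
    from dbgpt_hub.configs.config import INPUT_PROMPT, INSTRUCTION_ONE_SHOT_PROMPT

    schema = (
        "The database contains tables such as student. "
        "Table student has columns such as id, name, and gpa. id is the primary key."
    )
    question = "List the names of students with a GPA above 3.5."

    # The baked-in Example1 pair provides the single shot; the real schema and
    # question fill the "### New Instruction" and "###Input" slots.
    prompt = INSTRUCTION_ONE_SHOT_PROMPT.format(schema) + INPUT_PROMPT.format(question)
    print(prompt)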
diff --git a/dbgpt_hub/data/dataset_info.json b/dbgpt_hub/data/dataset_info.json
index f5988a1..49d14c8 100644
--- a/dbgpt_hub/data/dataset_info.json
+++ b/dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
       "history": "history"
     }
   },
+  "example_text2sql_train_one_shot": {
+    "file_name": "example_text2sql_train_one_shot.json",
+    "columns": {
+      "prompt": "instruction",
+      "query": "input",
+      "response": "output",
+      "history": "history"
+    }
+  },
   "example_rm_train": {
     "file_name": "oaast_rm_zh.json",
     "columns": {
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index a3e30c7..e869f2e 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -13,13 +13,15 @@
     DATA_PATH,
     INPUT_PROMPT,
     INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
 )
 
 
 class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.train_file = train_file
         self.dev_file = dev_file
+        self.num_shot = num_shot
 
     def decode_json_file(
         self, data_file_list, table_file, db_id_name, is_multiple_turn=False
@@ -87,6 +89,10 @@ def decode_json_file(
                     db_dict[item["db_id"]] = source
 
         res = []
+        base_instruction = INSTRUCTION_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
         for data in tqdm(datas):
             if data[db_id_name] in db_dict.keys():
                 if is_multiple_turn:  # multi-turn
@@ -94,7 +100,7 @@ def decode_json_file(
                     for interaction in data["interaction"]:
                         input = {
                             "db_id": data[db_id_name],
-                            "instruction": INSTRUCTION_PROMPT.format(
+                            "instruction": base_instruction.format(
                                 db_dict[data[db_id_name]]
                             ),
                             "input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -111,7 +117,7 @@ def decode_json_file(
                 else:  # single-turn
                     input = {
                         "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                             db_dict[data[db_id_name]]
                         ),
                         "input": INPUT_PROMPT.format(data["question"]),
@@ -167,3 +173,17 @@ def create_sft_raw_data(self):
         train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
     )
     precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    one_shot_process = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file,
+        dev_file=one_shot_all_in_one_dev_file,
+        num_shot=1,
+    )
+    one_shot_process.create_sft_raw_data()
diff --git a/dbgpt_hub/scripts/train_sft.sh b/dbgpt_hub/scripts/train_sft.sh
index ae4b276..14eb4c2 100644
--- a/dbgpt_hub/scripts/train_sft.sh
+++ b/dbgpt_hub/scripts/train_sft.sh
@@ -5,11 +5,24 @@
 train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
 start_time=$(date +%s)
 echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
+# zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
 # the default params can run on a server with one A100 (40G) GPU; if your server cannot support them, set smaller params such as lora_rank, or use QLoRA with quant 4
 CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
     --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
     --max_source_length 2048 \
     --max_target_length 512 \
    --finetuning_type lora \
@@ -17,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
     --template llama2 \
     --lora_rank 64 \
     --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
     --overwrite_cache \
     --overwrite_output_dir \
     --per_device_train_batch_size 1 \
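For reference, a sketch (not part of the patch) of the record shape that create_sft_raw_data() writes to example_text2sql_train_one_shot.json. The field names follow the column mapping registered in dataset_info.json above; the values are invented, and the output and history entries are inferred from the "response" and "history" mappings, since the diff truncates before those fields are built.

    # Illustrative record shape only; real values come from the Spider-style
    # source data processed by ProcessSqlData(..., num_shot=1).
    record = {
        "db_id": "employee_db",  # hypothetical id; not mapped in dataset_info.json, so unused in training
        "instruction": "I want you to act as a SQL terminal ...",  # INSTRUCTION_ONE_SHOT_PROMPT.format(schema)
        "input": "###Input:\nHow many employees are there?\n\n###Response:",  # INPUT_PROMPT.format(question)
        "output": "SELECT count(*) FROM employee;",  # gold SQL, per the "response" -> "output" mapping
        "history": [],  # assumed empty for single-turn examples
    }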