From 45d2c5564d94ec16972b287665e1488a8a4fedf9 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Mon, 25 Dec 2023 17:34:30 +0800
Subject: [PATCH 1/4] docs: update baseline result

---
 README.md    | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++
 README.zh.md | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 522 insertions(+)

diff --git a/README.md b/README.md
index 6d098a9..d5fd01f 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,267 @@
 [**简体中文**](README.zh.md) | [**Discord**](https://discord.gg/7uQnPuveTY) | [**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC) | [**Huggingface**](https://huggingface.co/eosphoros) | [**Community**](https://github.com/eosphoros-ai/community)
 
+## Baseline
+- update time: 2023/12/08
+- metric: execution accuracy (ex)
+- for more details, refer to [docs/eval_llm_result.md](https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/docs/eval_llm_result.md)
+
+| Model                  | Method | Easy  | Medium | Hard  | Extra | All   |
+| ---------------------- | ------ | ----- | ------ | ----- | ----- | ----- |
+| Llama2-7B-Chat         | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.887 | 0.641  | 0.489 | 0.331 | 0.626 |
+|                        | qlora  | 0.847 | 0.623  | 0.466 | 0.361 | 0.608 |
+| Llama2-13B-Chat        | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.907 | 0.729  | 0.552 | 0.343 | 0.68  |
+|                        | qlora  | 0.911 | 0.7    | 0.552 | 0.319 | 0.664 |
+| CodeLlama-7B-Instruct  | base   | 0.214 | 0.177  | 0.092 | 0.036 | 0.149 |
+|                        | lora   | 0.923 | 0.756  | 0.586 | 0.349 | 0.702 |
+|                        | qlora  | 0.911 | 0.751  | 0.598 | 0.331 | 0.696 |
+| CodeLlama-13B-Instruct | base   | 0.698 | 0.601  | 0.408 | 0.271 | 0.539 |
+|                        | lora   | 0.94  | 0.789  | 0.684 | 0.404 | 0.746 |
+|                        | qlora  | 0.94  | 0.774  | 0.626 | 0.392 | 0.727 |
+| Baichuan2-7B-Chat      | base   | 0.577 | 0.352  | 0.201 | 0.066 | 0.335 |
+|                        | lora   | 0.871 | 0.63   | 0.448 | 0.295 | 0.603 |
+|                        | qlora  | 0.891 | 0.637  | 0.489 | 0.331 | 0.624 |
+| Baichuan2-13B-Chat     | base   | 0.581 | 0.413  | 0.264 | 0.187 | 0.392 |
+|                        | lora   | 0.903 | 0.702  | 0.569 | 0.392 | 0.678 |
+|                        | qlora  | 0.895 | 0.675  | 0.58  | 0.343 | 0.659 |
+| Qwen-7B-Chat           | base   | 0.395 | 0.256  | 0.138 | 0.042 | 0.235 |
+|                        | lora   | 0.855 | 0.688  | 0.575 | 0.331 | 0.652 |
+|                        | qlora  | 0.911 | 0.675  | 0.575 | 0.343 | 0.662 |
+| Qwen-14B-Chat          | base   | 0.871 | 0.632  | 0.368 | 0.181 | 0.573 |
+|                        | lora   | 0.895 | 0.702  | 0.552 | 0.331 | 0.663 |
+|                        | qlora  | 0.919 | 0.744  | 0.598 | 0.367 | 0.701 |
+| ChatGLM3-6b            | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.855 | 0.605  | 0.477 | 0.271 | 0.59  |
+|                        | qlora  | 0.843 | 0.603  | 0.506 | 0.211 | 0.581 |
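
A note on the metric: execution accuracy (ex) executes both the gold query and the model's predicted query against the question's database, and the prediction scores only if the two result sets match. Below is a minimal sketch of such a check, assuming one SQLite file per question and an invented helper name; the repo's actual scoring is the pipeline described in docs/eval_llm_result.md.

```python
import sqlite3

def execution_match(gold_sql: str, pred_sql: str, db_file: str) -> bool:
    """Return True if gold and predicted SQL yield the same result set."""
    conn = sqlite3.connect(db_file)
    try:
        gold = conn.execute(gold_sql).fetchall()
        pred = conn.execute(pred_sql).fetchall()
    except sqlite3.Error:
        return False  # a prediction that fails to execute scores 0
    finally:
        conn.close()
    # Compare as multisets so row order does not affect the score;
    # a stricter checker would preserve order for ORDER BY queries.
    return sorted(map(repr, gold)) == sorted(map(repr, pred))
```
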
+
+
 ## Contents
 - [DB-GPT-Hub: Text-to-SQL parsing with LLMs](#db-gpt-hub-text-to-sql-parsing-with-llms)
 - [Contents](#contents)
diff --git a/README.zh.md b/README.zh.md
index 7dbe21d..4657a1c 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -26,6 +26,267 @@
 [**英文**](README.md) | [**Discord**](https://discord.gg/7uQnPuveTY) | [**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC) | [**Huggingface**](https://huggingface.co/eosphoros) | [**Community**](https://github.com/eosphoros-ai/community)
 
+## Baseline
+- 更新日期: 2023/12/08
+- 评价指标: execution accuracy (ex)
+- 详情参考 [docs/eval_llm_result.md](https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/docs/eval_llm_result.md)
+
+| Model                  | Method | Easy  | Medium | Hard  | Extra | All   |
+| ---------------------- | ------ | ----- | ------ | ----- | ----- | ----- |
+| Llama2-7B-Chat         | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.887 | 0.641  | 0.489 | 0.331 | 0.626 |
+|                        | qlora  | 0.847 | 0.623  | 0.466 | 0.361 | 0.608 |
+| Llama2-13B-Chat        | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.907 | 0.729  | 0.552 | 0.343 | 0.68  |
+|                        | qlora  | 0.911 | 0.7    | 0.552 | 0.319 | 0.664 |
+| CodeLlama-7B-Instruct  | base   | 0.214 | 0.177  | 0.092 | 0.036 | 0.149 |
+|                        | lora   | 0.923 | 0.756  | 0.586 | 0.349 | 0.702 |
+|                        | qlora  | 0.911 | 0.751  | 0.598 | 0.331 | 0.696 |
+| CodeLlama-13B-Instruct | base   | 0.698 | 0.601  | 0.408 | 0.271 | 0.539 |
+|                        | lora   | 0.94  | 0.789  | 0.684 | 0.404 | 0.746 |
+|                        | qlora  | 0.94  | 0.774  | 0.626 | 0.392 | 0.727 |
+| Baichuan2-7B-Chat      | base   | 0.577 | 0.352  | 0.201 | 0.066 | 0.335 |
+|                        | lora   | 0.871 | 0.63   | 0.448 | 0.295 | 0.603 |
+|                        | qlora  | 0.891 | 0.637  | 0.489 | 0.331 | 0.624 |
+| Baichuan2-13B-Chat     | base   | 0.581 | 0.413  | 0.264 | 0.187 | 0.392 |
+|                        | lora   | 0.903 | 0.702  | 0.569 | 0.392 | 0.678 |
+|                        | qlora  | 0.895 | 0.675  | 0.58  | 0.343 | 0.659 |
+| Qwen-7B-Chat           | base   | 0.395 | 0.256  | 0.138 | 0.042 | 0.235 |
+|                        | lora   | 0.855 | 0.688  | 0.575 | 0.331 | 0.652 |
+|                        | qlora  | 0.911 | 0.675  | 0.575 | 0.343 | 0.662 |
+| Qwen-14B-Chat          | base   | 0.871 | 0.632  | 0.368 | 0.181 | 0.573 |
+|                        | lora   | 0.895 | 0.702  | 0.552 | 0.331 | 0.663 |
+|                        | qlora  | 0.919 | 0.744  | 0.598 | 0.367 | 0.701 |
+| ChatGLM3-6b            | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.855 | 0.605  | 0.477 | 0.271 | 0.59  |
+|                        | qlora  | 0.843 | 0.603  | 0.506 | 0.211 | 0.581 |
+
 ## Contents
 - [DB-GPT-Hub:利用LLMs实现Text-to-SQL](#db-gpt-hub利用llms实现text-to-sql)
 - [Contents](#contents)

From ed9b2cb733930af740ce6b8ebfb97441146399c4 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Wed, 27 Dec 2023 17:56:41 +0800
Subject: [PATCH 2/4] feat: support one-shot

---
 dbgpt_hub/configs/config.py                | 25 ++++++++++++++++++++++
 dbgpt_hub/data/dataset_info.json           |  9 ++++++++
 dbgpt_hub/data_process/sql_data_process.py | 22 ++++++++++++++-----
 dbgpt_hub/scripts/train_sft.sh             | 19 +++++++++++++---
 4 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index 04341a4..f5b2563 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -89,6 +89,31 @@
 ##Instruction:\n{}\n"""
 INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
 
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You need only to return the sql command to me. \
+First, I will show you few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n###New Instruction:\n{}\n"""
+
+# EXAMPLES =[EXAMPLE1, EXAMPLE1]
+
+# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+# \n###New Instruction:\n{}\n"
+
+### test--------------------
+
+
 # METHODS = ["full", "freeze", "lora"]
 # STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
diff --git a/dbgpt_hub/data/dataset_info.json b/dbgpt_hub/data/dataset_info.json
index f5988a1..49d14c8 100644
--- a/dbgpt_hub/data/dataset_info.json
+++ b/dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
             "history": "history"
         }
     },
+    "example_text2sql_train_one_shot": {
+        "file_name": "example_text2sql_train_one_shot.json",
+        "columns": {
+            "prompt": "instruction",
+            "query": "input",
+            "response": "output",
+            "history": "history"
+        }
+    },
     "example_rm_train": {
         "file_name": "oaast_rm_zh.json",
         "columns": {
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index a3e30c7..b5c6734 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -13,14 +13,14 @@
     DATA_PATH,
     INPUT_PROMPT,
     INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
 )
 
-
 class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.train_file = train_file
         self.dev_file = dev_file
-
+        self.num_shot = num_shot
     def decode_json_file(
         self, data_file_list, table_file, db_id_name, is_multiple_turn=False
     ):
@@ -87,6 +87,10 @@ def decode_json_file(
             db_dict[item["db_id"]] = source
 
         res = []
+        base_instruction = INSTRUCTION_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
         for data in tqdm(datas):
             if data[db_id_name] in db_dict.keys():
                 if is_multiple_turn:  # 多轮
@@ -94,7 +98,7 @@ def decode_json_file(
                 for interaction in data["interaction"]:
                     input = {
                         "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                             db_dict[data[db_id_name]]
                         ),
                         "input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -111,7 +115,7 @@ def decode_json_file(
             else:  # 单轮
                 input = {
                     "db_id": data[db_id_name],
-                    "instruction": INSTRUCTION_PROMPT.format(
+                    "instruction": base_instruction.format(
                         db_dict[data[db_id_name]]
                     ),
                     "input": INPUT_PROMPT.format(data["question"]),
@@ -167,3 +171,11 @@ def create_sft_raw_data(self):
         train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
     )
     precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train_one_shot.json")
+    one_shot_all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev_one_shot.json")
+    one_shot_precess = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file, dev_file=one_shot_all_in_one_dev_file, num_shot=1
+    )
+    one_shot_precess.create_sft_raw_data()
\ No newline at end of file
diff --git a/dbgpt_hub/scripts/train_sft.sh b/dbgpt_hub/scripts/train_sft.sh
index ae4b276..14eb4c2 100644
--- a/dbgpt_hub/scripts/train_sft.sh
+++ b/dbgpt_hub/scripts/train_sft.sh
@@ -5,11 +5,24 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
 start_time=$(date +%s)
 echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
 
+# # zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
 # the default param set could be run in a server with one a100(40G) gpu, if your server not support the set,you can set smaller param such as lora_rank and use qlora with quant 4 eg...
 CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
     --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
     --max_source_length 2048 \
     --max_target_length 512 \
     --finetuning_type lora \
@@ -17,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
     --template llama2 \
     --lora_rank 64 \
     --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
     --overwrite_cache \
     --overwrite_output_dir \
     --per_device_train_batch_size 1 \
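
The data-level effect of this patch is easy to make concrete: a one-shot sample is exactly the zero-shot sample with the worked employee/salary/position example prepended to the instruction, so the trainer only needs the new `dataset_info.json` key, not new code. A sketch of the record `decode_json_file` now emits — the schema string, `db_id`, and question below are invented stand-ins for what the class builds from Spider's `tables.json`, and the `output`/`history` fields follow the column mapping registered above:

```python
from dbgpt_hub.configs.config import (
    INPUT_PROMPT,
    INSTRUCTION_PROMPT,
    INSTRUCTION_ONE_SHOT_PROMPT,
)

num_shot = 1  # mirrors the new ProcessSqlData(..., num_shot=1) switch

# Same selection logic the patch adds inside decode_json_file.
base_instruction = INSTRUCTION_ONE_SHOT_PROMPT if num_shot == 1 else INSTRUCTION_PROMPT

# Invented stand-in for the schema serialization built from tables.json.
db_schema = (
    "The database contains tables such as singer. "
    "Table singer has columns such as singer_id, name, age. "
    "singer_id is the primary key."
)

sample = {
    "db_id": "concert_singer",
    "instruction": base_instruction.format(db_schema),
    "input": INPUT_PROMPT.format("How many singers do we have?"),
    "output": "SELECT count(*) FROM singer",
    "history": [],
}
```

With `num_shot=1`, `sample["instruction"]` opens with the fixed worked example before the real schema; that longer prompt also eats into the `--max_source_length 2048` budget, which is worth remembering for wide schemas.
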
From 7051a0e5617f672c30af85673e948bb3a96b55b5 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Wed, 27 Dec 2023 18:02:13 +0800
Subject: [PATCH 3/4] feat: add one-shot train

---
 dbgpt_hub/configs/config.py                | 25 +++++++++++++++++++++
 dbgpt_hub/data/dataset_info.json           |  9 ++++++++
 dbgpt_hub/data_process/sql_data_process.py | 26 +++++++++++++++++++---
 dbgpt_hub/scripts/train_sft.sh             | 19 +++++++++++++---
 4 files changed, 73 insertions(+), 6 deletions(-)

diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index 04341a4..f5b2563 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -89,6 +89,31 @@
 ##Instruction:\n{}\n"""
 INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
 
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You need only to return the sql command to me. \
+First, I will show you few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n###New Instruction:\n{}\n"""
+
+# EXAMPLES =[EXAMPLE1, EXAMPLE1]
+
+# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+# \n###New Instruction:\n{}\n"
+
+### test--------------------
+
+
 # METHODS = ["full", "freeze", "lora"]
 # STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
diff --git a/dbgpt_hub/data/dataset_info.json b/dbgpt_hub/data/dataset_info.json
index f5988a1..49d14c8 100644
--- a/dbgpt_hub/data/dataset_info.json
+++ b/dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
             "history": "history"
         }
     },
+    "example_text2sql_train_one_shot": {
+        "file_name": "example_text2sql_train_one_shot.json",
+        "columns": {
+            "prompt": "instruction",
+            "query": "input",
+            "response": "output",
+            "history": "history"
+        }
+    },
     "example_rm_train": {
         "file_name": "oaast_rm_zh.json",
         "columns": {
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index a3e30c7..e869f2e 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -13,13 +13,15 @@
     DATA_PATH,
     INPUT_PROMPT,
     INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
 )
 
 class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.train_file = train_file
         self.dev_file = dev_file
+        self.num_shot = num_shot
 
     def decode_json_file(
         self, data_file_list, table_file, db_id_name, is_multiple_turn=False
     ):
@@ -87,6 +89,10 @@ def decode_json_file(
             db_dict[item["db_id"]] = source
 
         res = []
+        base_instruction = INSTRUCTION_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
         for data in tqdm(datas):
             if data[db_id_name] in db_dict.keys():
                 if is_multiple_turn:  # 多轮
@@ -94,7 +100,7 @@ def decode_json_file(
                 for interaction in data["interaction"]:
                     input = {
                         "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                             db_dict[data[db_id_name]]
                         ),
                         "input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -111,7 +117,7 @@ def decode_json_file(
             else:  # 单轮
                 input = {
                     "db_id": data[db_id_name],
-                    "instruction": INSTRUCTION_PROMPT.format(
+                    "instruction": base_instruction.format(
                         db_dict[data[db_id_name]]
                     ),
                     "input": INPUT_PROMPT.format(data["question"]),
@@ -167,3 +173,17 @@ def create_sft_raw_data(self):
         train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
     )
     precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    one_shot_precess = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file,
+        dev_file=one_shot_all_in_one_dev_file,
+        num_shot=1,
+    )
+    one_shot_precess.create_sft_raw_data()
diff --git a/dbgpt_hub/scripts/train_sft.sh b/dbgpt_hub/scripts/train_sft.sh
index ae4b276..14eb4c2 100644
--- a/dbgpt_hub/scripts/train_sft.sh
+++ b/dbgpt_hub/scripts/train_sft.sh
@@ -5,11 +5,24 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
 start_time=$(date +%s)
 echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
 
+# # zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
 # the default param set could be run in a server with one a100(40G) gpu, if your server not support the set,you can set smaller param such as lora_rank and use qlora with quant 4 eg...
 CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
     --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
     --max_source_length 2048 \
     --max_target_length 512 \
     --finetuning_type lora \
@@ -17,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
     --template llama2 \
     --lora_rank 64 \
     --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
     --overwrite_cache \
     --overwrite_output_dir \
     --per_device_train_batch_size 1 \

From 4242f3a323906bc6f4ab0eefc1af1009dda2b111 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Thu, 28 Dec 2023 10:52:00 +0800
Subject: [PATCH 4/4] perf: code style

---
 dbgpt_hub/data_process/sql_data_process.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index 3739a9e..e869f2e 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -16,11 +16,13 @@
     INSTRUCTION_ONE_SHOT_PROMPT,
 )
 
+
 class ProcessSqlData:
     def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.train_file = train_file
         self.dev_file = dev_file
         self.num_shot = num_shot
+
     def decode_json_file(
         self, data_file_list, table_file, db_id_name, is_multiple_turn=False
     ):
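
End to end, the intended workflow for this series appears to be: run the data-processing step so the `*_one_shot` files are generated, then launch `train_sft.sh` with `num_shot=1` so `--dataset` resolves to the new `dataset_info.json` key. The sketch below sanity-checks that resolution without starting a training run; the literal `DATA_PATH` value and the prompt/target concatenation are assumptions (the real values live in `dbgpt_hub/configs/config.py` and `sft_train.py`), not something these patches state.

```python
import json
import os

DATA_PATH = "dbgpt_hub/data"  # assumed to match configs.config.DATA_PATH

# The trainer looks up --dataset names in dataset_info.json; this is the
# entry added by the patch series.
with open(os.path.join(DATA_PATH, "dataset_info.json")) as f:
    info = json.load(f)["example_text2sql_train_one_shot"]

cols = info["columns"]  # {"prompt": "instruction", "query": "input", ...}
with open(os.path.join(DATA_PATH, info["file_name"])) as f:
    record = json.load(f)[0]

# Reassemble one (prompt, target) pair the way an SFT trainer typically would.
prompt = record[cols["prompt"]] + record[cols["query"]]
target = record[cols["response"]]
print(prompt[-300:])
print(target)
```
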