From 45d2c5564d94ec16972b287665e1488a8a4fedf9 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Mon, 25 Dec 2023 17:34:30 +0800
Subject: [PATCH 1/4] docs: update baseline result
---
README.md | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++
README.zh.md | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 522 insertions(+)
diff --git a/README.md b/README.md
index 6d098a9..d5fd01f 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,267 @@
[**简体中文**](README.zh.md) | [**Discord**](https://discord.gg/7uQnPuveTY) | [**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC) | [**Huggingface**](https://huggingface.co/eosphoros) | [**Community**](https://github.com/eosphoros-ai/community)
+## Baseline
+- update time: 2023/12/08
+- metric: execution accuracy (ex); a minimal sketch of this check appears after this diff
+- for more details, see [docs/eval_llm_result.md](https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/docs/eval_llm_result.md)
+
+
+| Model                  | Method | Easy  | Medium | Hard  | Extra | All   |
+|------------------------|--------|-------|--------|-------|-------|-------|
+| Llama2-7B-Chat         | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.887 | 0.641  | 0.489 | 0.331 | 0.626 |
+|                        | qlora  | 0.847 | 0.623  | 0.466 | 0.361 | 0.608 |
+| Llama2-13B-Chat        | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.907 | 0.729  | 0.552 | 0.343 | 0.68  |
+|                        | qlora  | 0.911 | 0.7    | 0.552 | 0.319 | 0.664 |
+| CodeLlama-7B-Instruct  | base   | 0.214 | 0.177  | 0.092 | 0.036 | 0.149 |
+|                        | lora   | 0.923 | 0.756  | 0.586 | 0.349 | 0.702 |
+|                        | qlora  | 0.911 | 0.751  | 0.598 | 0.331 | 0.696 |
+| CodeLlama-13B-Instruct | base   | 0.698 | 0.601  | 0.408 | 0.271 | 0.539 |
+|                        | lora   | 0.94  | 0.789  | 0.684 | 0.404 | 0.746 |
+|                        | qlora  | 0.94  | 0.774  | 0.626 | 0.392 | 0.727 |
+| Baichuan2-7B-Chat      | base   | 0.577 | 0.352  | 0.201 | 0.066 | 0.335 |
+|                        | lora   | 0.871 | 0.63   | 0.448 | 0.295 | 0.603 |
+|                        | qlora  | 0.891 | 0.637  | 0.489 | 0.331 | 0.624 |
+| Baichuan2-13B-Chat     | base   | 0.581 | 0.413  | 0.264 | 0.187 | 0.392 |
+|                        | lora   | 0.903 | 0.702  | 0.569 | 0.392 | 0.678 |
+|                        | qlora  | 0.895 | 0.675  | 0.58  | 0.343 | 0.659 |
+| Qwen-7B-Chat           | base   | 0.395 | 0.256  | 0.138 | 0.042 | 0.235 |
+|                        | lora   | 0.855 | 0.688  | 0.575 | 0.331 | 0.652 |
+|                        | qlora  | 0.911 | 0.675  | 0.575 | 0.343 | 0.662 |
+| Qwen-14B-Chat          | base   | 0.871 | 0.632  | 0.368 | 0.181 | 0.573 |
+|                        | lora   | 0.895 | 0.702  | 0.552 | 0.331 | 0.663 |
+|                        | qlora  | 0.919 | 0.744  | 0.598 | 0.367 | 0.701 |
+| ChatGLM3-6b            | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.855 | 0.605  | 0.477 | 0.271 | 0.59  |
+|                        | qlora  | 0.843 | 0.603  | 0.506 | 0.211 | 0.581 |
+
## Contents
- [DB-GPT-Hub: Text-to-SQL parsing with LLMs](#db-gpt-hub-text-to-sql-parsing-with-llms)
- [Contents](#contents)
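For reference while reading the table above: execution accuracy (ex) scores a prediction as correct when the predicted SQL, executed against the example database, returns the same result set as the gold SQL. The project's actual harness follows the Spider evaluation tooling described in the linked docs/eval_llm_result.md; the sketch below is only a minimal illustration of the idea, with a hypothetical `db_path` argument and none of the official suite's value normalization.

```python
import sqlite3


def execution_match(db_path: str, gold_sql: str, pred_sql: str) -> bool:
    """Return True when gold and predicted SQL yield the same result set."""
    conn = sqlite3.connect(db_path)
    try:
        gold_rows = conn.execute(gold_sql).fetchall()
        try:
            pred_rows = conn.execute(pred_sql).fetchall()
        except sqlite3.Error:
            return False  # an un-executable prediction counts as wrong
        # Compare as multisets so incidental row order does not matter.
        return sorted(map(repr, gold_rows)) == sorted(map(repr, pred_rows))
    finally:
        conn.close()
```

Each difficulty column (Easy/Medium/Hard/Extra) is then the mean of this boolean over the corresponding Spider bucket, and All is the mean over the full dev set.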
diff --git a/README.zh.md b/README.zh.md
index 7dbe21d..4657a1c 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -26,6 +26,267 @@
[**English**](README.md) | [**Discord**](https://discord.gg/7uQnPuveTY) | [**Wechat**](https://github.com/eosphoros-ai/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC) | [**Huggingface**](https://huggingface.co/eosphoros) | [**Community**](https://github.com/eosphoros-ai/community)
+
+## Baseline
+- update time: 2023/12/08
+- metric: execution accuracy (ex); the All column is checked as a bucket-weighted average after this diff
+- for more details, see [docs/eval_llm_result.md](https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/docs/eval_llm_result.md)
+
+
+| Model                  | Method | Easy  | Medium | Hard  | Extra | All   |
+|------------------------|--------|-------|--------|-------|-------|-------|
+| Llama2-7B-Chat         | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.887 | 0.641  | 0.489 | 0.331 | 0.626 |
+|                        | qlora  | 0.847 | 0.623  | 0.466 | 0.361 | 0.608 |
+| Llama2-13B-Chat        | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.907 | 0.729  | 0.552 | 0.343 | 0.68  |
+|                        | qlora  | 0.911 | 0.7    | 0.552 | 0.319 | 0.664 |
+| CodeLlama-7B-Instruct  | base   | 0.214 | 0.177  | 0.092 | 0.036 | 0.149 |
+|                        | lora   | 0.923 | 0.756  | 0.586 | 0.349 | 0.702 |
+|                        | qlora  | 0.911 | 0.751  | 0.598 | 0.331 | 0.696 |
+| CodeLlama-13B-Instruct | base   | 0.698 | 0.601  | 0.408 | 0.271 | 0.539 |
+|                        | lora   | 0.94  | 0.789  | 0.684 | 0.404 | 0.746 |
+|                        | qlora  | 0.94  | 0.774  | 0.626 | 0.392 | 0.727 |
+| Baichuan2-7B-Chat      | base   | 0.577 | 0.352  | 0.201 | 0.066 | 0.335 |
+|                        | lora   | 0.871 | 0.63   | 0.448 | 0.295 | 0.603 |
+|                        | qlora  | 0.891 | 0.637  | 0.489 | 0.331 | 0.624 |
+| Baichuan2-13B-Chat     | base   | 0.581 | 0.413  | 0.264 | 0.187 | 0.392 |
+|                        | lora   | 0.903 | 0.702  | 0.569 | 0.392 | 0.678 |
+|                        | qlora  | 0.895 | 0.675  | 0.58  | 0.343 | 0.659 |
+| Qwen-7B-Chat           | base   | 0.395 | 0.256  | 0.138 | 0.042 | 0.235 |
+|                        | lora   | 0.855 | 0.688  | 0.575 | 0.331 | 0.652 |
+|                        | qlora  | 0.911 | 0.675  | 0.575 | 0.343 | 0.662 |
+| Qwen-14B-Chat          | base   | 0.871 | 0.632  | 0.368 | 0.181 | 0.573 |
+|                        | lora   | 0.895 | 0.702  | 0.552 | 0.331 | 0.663 |
+|                        | qlora  | 0.919 | 0.744  | 0.598 | 0.367 | 0.701 |
+| ChatGLM3-6b            | base   | 0     | 0      | 0     | 0     | 0     |
+|                        | lora   | 0.855 | 0.605  | 0.477 | 0.271 | 0.59  |
+|                        | qlora  | 0.843 | 0.603  | 0.506 | 0.211 | 0.581 |
+
## Contents
- [DB-GPT-Hub: Text-to-SQL parsing with LLMs](#db-gpt-hub利用llms实现text-to-sql)
- [Contents](#contents)
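One sanity check the table supports: the All column is consistent with a count-weighted mean of the four difficulty buckets, assuming the standard Spider dev-split sizes (248 easy, 446 medium, 174 hard, 166 extra; an assumption to verify against the linked eval doc rather than a figure stated in this patch):

```python
# Recompute "All" as the bucket-count-weighted mean of per-difficulty
# execution accuracy. Bucket sizes assume the standard Spider dev split.
BUCKETS = {"easy": 248, "medium": 446, "hard": 174, "extra": 166}


def overall(easy: float, medium: float, hard: float, extra: float) -> float:
    scores = {"easy": easy, "medium": medium, "hard": hard, "extra": extra}
    return sum(scores[k] * n for k, n in BUCKETS.items()) / sum(BUCKETS.values())


# CodeLlama-13B-Instruct, lora row from the table above:
print(round(overall(0.94, 0.789, 0.684, 0.404), 3))  # -> 0.746
```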
From ed9b2cb733930af740ce6b8ebfb97441146399c4 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Wed, 27 Dec 2023 17:56:41 +0800
Subject: [PATCH 2/4] feat: support one-shot
---
dbgpt_hub/configs/config.py | 25 ++++++++++++++++++++++
dbgpt_hub/data/dataset_info.json | 9 ++++++++
dbgpt_hub/data_process/sql_data_process.py | 22 ++++++++++++++-----
dbgpt_hub/scripts/train_sft.sh | 19 +++++++++++++---
4 files changed, 67 insertions(+), 8 deletions(-)
diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index 04341a4..f5b2563 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -89,6 +89,31 @@
##Instruction:\n{}\n"""
INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You only need to return the SQL command to me. \
+First, I will show you a few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n###New Instruction:\n{}\n"""
+
# METHODS = ["full", "freeze", "lora"]
# STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
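To see how the new template is consumed: `decode_json_file` (changed below) formats the schema description into the instruction and keeps the question in a separate input field. A minimal sketch of the composed sample, with a hypothetical schema string and question for illustration:

```python
from dbgpt_hub.configs.config import INPUT_PROMPT, INSTRUCTION_ONE_SHOT_PROMPT

# Hypothetical schema serialization and question, for illustration only;
# the real strings are built from the Spider tables file.
schema = (
    "The database contains tables such as singer. Table singer has columns "
    "such as singer_id, name, and age. singer_id is the primary key."
)
question = "How many singers are there?"

sample = {
    "instruction": INSTRUCTION_ONE_SHOT_PROMPT.format(schema),
    "input": INPUT_PROMPT.format(question),
    "output": "SELECT count(*) FROM singer;",
}
# The model sees instruction + input and is trained to emit output.
print(sample["instruction"] + sample["input"])
```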
diff --git a/dbgpt_hub/data/dataset_info.json b/dbgpt_hub/data/dataset_info.json
index f5988a1..49d14c8 100644
--- a/dbgpt_hub/data/dataset_info.json
+++ b/dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
      "history": "history"
    }
  },
+  "example_text2sql_train_one_shot": {
+    "file_name": "example_text2sql_train_one_shot.json",
+    "columns": {
+      "prompt": "instruction",
+      "query": "input",
+      "response": "output",
+      "history": "history"
+    }
+  },
  "example_rm_train": {
    "file_name": "oaast_rm_zh.json",
    "columns": {
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index a3e30c7..b5c6734 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -13,14 +13,14 @@
    DATA_PATH,
    INPUT_PROMPT,
    INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
)
-
class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
        self.train_file = train_file
        self.dev_file = dev_file
-
+        self.num_shot = num_shot
    def decode_json_file(
        self, data_file_list, table_file, db_id_name, is_multiple_turn=False
    ):
@@ -87,6 +87,10 @@ def decode_json_file(
            db_dict[item["db_id"]] = source
        res = []
+        base_instruction = INSTRUCTION_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
        for data in tqdm(datas):
            if data[db_id_name] in db_dict.keys():
                if is_multiple_turn:  # multi-turn
@@ -94,7 +98,7 @@
                    for interaction in data["interaction"]:
                        input = {
                            "db_id": data[db_id_name],
-                            "instruction": INSTRUCTION_PROMPT.format(
+                            "instruction": base_instruction.format(
                                db_dict[data[db_id_name]]
                            ),
                            "input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -111,7 +115,7 @@
                else:  # single-turn
                    input = {
                        "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                            db_dict[data[db_id_name]]
                        ),
                        "input": INPUT_PROMPT.format(data["question"]),
@@ -167,3 +171,11 @@ def create_sft_raw_data(self):
        train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
    )
    precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train_one_shot.json")
+    one_shot_all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev_one_shot.json")
+    one_shot_precess = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file, dev_file=one_shot_all_in_one_dev_file, num_shot=1
+    )
+    one_shot_precess.create_sft_raw_data()
\ No newline at end of file
diff --git a/dbgpt_hub/scripts/train_sft.sh b/dbgpt_hub/scripts/train_sft.sh
index ae4b276..14eb4c2 100644
--- a/dbgpt_hub/scripts/train_sft.sh
+++ b/dbgpt_hub/scripts/train_sft.sh
@@ -5,11 +5,24 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
start_time=$(date +%s)
echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
+# zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
# The default params are sized for a server with one A100 (40G) GPU; if your server cannot support them, use smaller values (e.g. lora_rank) or switch to qlora with quant 4.
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
    --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
    --max_source_length 2048 \
    --max_target_length 512 \
    --finetuning_type lora \
@@ -17,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
    --template llama2 \
    --lora_rank 64 \
    --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
    --overwrite_cache \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
From 7051a0e5617f672c30af85673e948bb3a96b55b5 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Wed, 27 Dec 2023 18:02:13 +0800
Subject: [PATCH 3/4] feat: add one-shot train
---
dbgpt_hub/configs/config.py | 25 +++++++++++++++++++++
dbgpt_hub/data/dataset_info.json | 9 ++++++++
dbgpt_hub/data_process/sql_data_process.py | 26 +++++++++++++++++++---
dbgpt_hub/scripts/train_sft.sh | 19 +++++++++++++---
4 files changed, 73 insertions(+), 6 deletions(-)
diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py
index 04341a4..f5b2563 100644
--- a/dbgpt_hub/configs/config.py
+++ b/dbgpt_hub/configs/config.py
@@ -89,6 +89,31 @@
##Instruction:\n{}\n"""
INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You only need to return the SQL command to me. \
+First, I will show you a few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n###New Instruction:\n{}\n"""
+
# METHODS = ["full", "freeze", "lora"]
# STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
diff --git a/dbgpt_hub/data/dataset_info.json b/dbgpt_hub/data/dataset_info.json
index f5988a1..49d14c8 100644
--- a/dbgpt_hub/data/dataset_info.json
+++ b/dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
      "history": "history"
    }
  },
+  "example_text2sql_train_one_shot": {
+    "file_name": "example_text2sql_train_one_shot.json",
+    "columns": {
+      "prompt": "instruction",
+      "query": "input",
+      "response": "output",
+      "history": "history"
+    }
+  },
  "example_rm_train": {
    "file_name": "oaast_rm_zh.json",
    "columns": {
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index a3e30c7..e869f2e 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -13,13 +13,15 @@
    DATA_PATH,
    INPUT_PROMPT,
    INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
)
class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
        self.train_file = train_file
        self.dev_file = dev_file
+        self.num_shot = num_shot
    def decode_json_file(
        self, data_file_list, table_file, db_id_name, is_multiple_turn=False
@@ -87,6 +89,10 @@ def decode_json_file(
            db_dict[item["db_id"]] = source
        res = []
+        base_instruction = INSTRUCTION_PROMPT
+        if self.num_shot == 1:
+            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT
+
        for data in tqdm(datas):
            if data[db_id_name] in db_dict.keys():
                if is_multiple_turn:  # multi-turn
@@ -94,7 +100,7 @@
                    for interaction in data["interaction"]:
                        input = {
                            "db_id": data[db_id_name],
-                            "instruction": INSTRUCTION_PROMPT.format(
+                            "instruction": base_instruction.format(
                                db_dict[data[db_id_name]]
                            ),
                            "input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -111,7 +117,7 @@
                else:  # single-turn
                    input = {
                        "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                            db_dict[data[db_id_name]]
                        ),
                        "input": INPUT_PROMPT.format(data["question"]),
@@ -167,3 +173,17 @@ def create_sft_raw_data(self):
        train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
    )
    precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    one_shot_precess = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file,
+        dev_file=one_shot_all_in_one_dev_file,
+        num_shot=1,
+    )
+    one_shot_precess.create_sft_raw_data()
diff --git a/dbgpt_hub/scripts/train_sft.sh b/dbgpt_hub/scripts/train_sft.sh
index ae4b276..14eb4c2 100644
--- a/dbgpt_hub/scripts/train_sft.sh
+++ b/dbgpt_hub/scripts/train_sft.sh
@@ -5,11 +5,24 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
start_time=$(date +%s)
echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
+# zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
# The default params are sized for a server with one A100 (40G) GPU; if your server cannot support them, use smaller values (e.g. lora_rank) or switch to qlora with quant 4.
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
    --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
    --max_source_length 2048 \
    --max_target_length 512 \
    --finetuning_type lora \
@@ -17,7 +30,7 @@ CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
    --template llama2 \
    --lora_rank 64 \
    --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
    --overwrite_cache \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
From 4242f3a323906bc6f4ab0eefc1af1009dda2b111 Mon Sep 17 00:00:00 2001
From: junewgl <1965259211@qq.com>
Date: Thu, 28 Dec 2023 10:52:00 +0800
Subject: [PATCH 4/4] perf: code style
---
dbgpt_hub/data_process/sql_data_process.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py
index 3739a9e..e869f2e 100644
--- a/dbgpt_hub/data_process/sql_data_process.py
+++ b/dbgpt_hub/data_process/sql_data_process.py
@@ -16,11 +16,13 @@
    INSTRUCTION_ONE_SHOT_PROMPT,
)
+
class ProcessSqlData:
    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
        self.train_file = train_file
        self.dev_file = dev_file
        self.num_shot = num_shot
+
    def decode_json_file(
        self, data_file_list, table_file, db_id_name, is_multiple_turn=False
    ):