feat: add one-shot train #197

Merged 6 commits on Dec 28, 2023
25 changes: 25 additions & 0 deletions dbgpt_hub/configs/config.py
@@ -89,6 +89,31 @@
 ##Instruction:\n{}\n"""
 INPUT_PROMPT = "###Input:\n{}\n\n###Response:"
 
+INSTRUCTION_ONE_SHOT_PROMPT = """\
+I want you to act as a SQL terminal in front of an example database. \
+You need only to return the sql command to me. \
+First, I will show you a few examples of an instruction followed by the correct SQL response. \
+Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
+\n### Example1 Instruction:
+The database contains tables such as employee, salary, and position. \
+Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
+Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
+Table position has columns such as position_id, title, and department. position_id is the primary key. \
+The employee_id of salary is the foreign key of employee_id of employee. \
+The position_id of employee is the foreign key of position_id of position.\
+\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+\n### New Instruction:\n{}\n"""
+
+# EXAMPLES = [EXAMPLE1, EXAMPLE1]
+
+# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
+# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
+# \n### New Instruction:\n{}\n"
+
+### test--------------------
+
+
 # METHODS = ["full", "freeze", "lora"]
 
 # STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
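For context: like `INSTRUCTION_PROMPT`, the new template has a single `{}` placeholder that receives the serialized schema of the target database, while the in-context example stays fixed. A minimal sketch of how it gets filled; the toy schema is illustrative, and the import path assumes the module layout shown in this diff:

```python
# Sketch only: a toy schema stands in for the serialization that
# sql_data_process.py builds from the real table metadata.
from dbgpt_hub.configs.config import INPUT_PROMPT, INSTRUCTION_ONE_SHOT_PROMPT

schema = (
    "The database contains tables such as book. "
    "Table book has columns such as book_id, title, and author. "
    "book_id is the primary key."
)

# Fixed example first, then the new schema, then the question.
prompt = INSTRUCTION_ONE_SHOT_PROMPT.format(schema) + INPUT_PROMPT.format(
    "List all book titles."
)
print(prompt)
```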
9 changes: 9 additions & 0 deletions dbgpt_hub/data/dataset_info.json
@@ -8,6 +8,15 @@
"history": "history"
}
},
"example_text2sql_train_one_shot": {
"file_name": "example_text2sql_train_one_shot.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"example_rm_train": {
"file_name": "oaast_rm_zh.json",
"columns": {
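This entry registers the one-shot file under the same column mapping as the zero-shot dataset, so the loader expects each record to carry the same four keys. A hypothetical record, with field values abbreviated for illustration:

```python
# Hypothetical record shape in example_text2sql_train_one_shot.json,
# inferred from the column mapping above; the values are abbreviated.
record = {
    "instruction": "I want you to act as a SQL terminal ... ### New Instruction:\n<schema>\n",
    "input": "###Input:\n<question>\n\n###Response:",
    "output": "SELECT ...;",
    "history": [],
}
```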
26 changes: 23 additions & 3 deletions dbgpt_hub/data_process/sql_data_process.py
@@ -13,13 +13,15 @@
     DATA_PATH,
     INPUT_PROMPT,
     INSTRUCTION_PROMPT,
+    INSTRUCTION_ONE_SHOT_PROMPT,
 )
 
 
 class ProcessSqlData:
-    def __init__(self, train_file=None, dev_file=None) -> None:
+    def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
         self.train_file = train_file
         self.dev_file = dev_file
+        self.num_shot = num_shot
 
     def decode_json_file(
         self, data_file_list, table_file, db_id_name, is_multiple_turn=False
Expand Down Expand Up @@ -87,14 +89,18 @@ def decode_json_file(
db_dict[item["db_id"]] = source

res = []
base_instruction = INSTRUCTION_PROMPT
if self.num_shot == 1:
base_instruction = INSTRUCTION_ONE_SHOT_PROMPT

for data in tqdm(datas):
if data[db_id_name] in db_dict.keys():
if is_multiple_turn: # 多轮
history = []
for interaction in data["interaction"]:
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(
"instruction": base_instruction.format(
db_dict[data[db_id_name]]
),
"input": INPUT_PROMPT.format(interaction["utterance"]),
@@ -111,7 +117,7 @@ def decode_json_file(
                 else:  # single-turn
                     input = {
                         "db_id": data[db_id_name],
-                        "instruction": INSTRUCTION_PROMPT.format(
+                        "instruction": base_instruction.format(
                             db_dict[data[db_id_name]]
                         ),
                         "input": INPUT_PROMPT.format(data["question"]),
@@ -167,3 +173,17 @@ def create_sft_raw_data(self):
         train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
     )
     precess.create_sft_raw_data()
+
+    # one-shot
+    one_shot_all_in_one_train_file = os.path.join(
+        DATA_PATH, "example_text2sql_train_one_shot.json"
+    )
+    one_shot_all_in_one_dev_file = os.path.join(
+        DATA_PATH, "example_text2sql_dev_one_shot.json"
+    )
+    one_shot_precess = ProcessSqlData(
+        train_file=one_shot_all_in_one_train_file,
+        dev_file=one_shot_all_in_one_dev_file,
+        num_shot=1,
+    )
+    one_shot_precess.create_sft_raw_data()
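Taken together, the single-turn path now renders whichever template `num_shot` selects. A condensed, self-contained sketch of that branch; the abbreviated templates and the toy record below are stand-ins, not repo code:

```python
# Condensed sketch of the single-turn branch after this PR; templates
# are abbreviated and the record is a toy stand-in for the Spider-style
# JSON the real code reads.
INSTRUCTION_PROMPT = "...zero-shot preamble...\n##Instruction:\n{}\n"
INSTRUCTION_ONE_SHOT_PROMPT = "...preamble with one fixed example...\n### New Instruction:\n{}\n"
INPUT_PROMPT = "###Input:\n{}\n\n###Response:"

num_shot = 1
base_instruction = INSTRUCTION_ONE_SHOT_PROMPT if num_shot == 1 else INSTRUCTION_PROMPT

db_dict = {"toy_db": "Table t has columns such as a and b. a is the primary key."}
data = {"db_id": "toy_db", "question": "List all values of a.", "query": "SELECT a FROM t;"}

sample = {
    "db_id": data["db_id"],
    "instruction": base_instruction.format(db_dict[data["db_id"]]),
    "input": INPUT_PROMPT.format(data["question"]),
    "output": data["query"],
    "history": [],
}
print(sample["instruction"] + sample["input"])
```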
19 changes: 16 additions & 3 deletions dbgpt_hub/scripts/train_sft.sh
@@ -5,19 +5,32 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
 start_time=$(date +%s)
 echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}
 
+# zero-shot
+# num_shot=0
+
+# one-shot train
+num_shot=1
+
+dataset="example_text2sql_train"
+if [ "$num_shot" -eq 1 ]; then
+    dataset="example_text2sql_train_one_shot"
+fi
+model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
+output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"
+
 # The default parameter set runs on a server with a single A100 (40 GB) GPU; if your server cannot support it, use smaller settings, e.g. a lower lora_rank, or QLoRA with 4-bit quantization.
 CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
-    --model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
+    --model_name_or_path $model_name_or_path \
     --do_train \
-    --dataset example_text2sql_train \
+    --dataset $dataset \
     --max_source_length 2048 \
     --max_target_length 512 \
     --finetuning_type lora \
     --lora_target q_proj,v_proj \
     --template llama2 \
     --lora_rank 64 \
     --lora_alpha 32 \
-    --output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
+    --output_dir $output_dir \
     --overwrite_cache \
     --overwrite_output_dir \
     --per_device_train_batch_size 1 \
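With `num_shot=1` the script trains on the one-shot dataset registered in dataset_info.json; setting it back to 0 restores the original zero-shot run, since the dataset name is derived from the flag. The `--max_source_length` of 2048 should also leave enough room for the longer one-shot instruction, though that headroom is an inference rather than something measured in this PR.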