Skip to content

Commit

Permalink
feat: add one-shot train (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
junewgl authored Dec 28, 2023
1 parent d6a8bb9 commit c066049
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 6 deletions.
25 changes: 25 additions & 0 deletions dbgpt_hub/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@
##Instruction:\n{}\n"""
INPUT_PROMPT = "###Input:\n{}\n\n###Response:"

# One-shot prompt template: identical contract to INSTRUCTION_PROMPT (a single
# "{}" placeholder that receives the database schema description), but prefixed
# with one fully worked example (schema + question + SQL) so the model sees the
# expected answer format before the real instruction.
# Fixes vs. the draft: "few examples" -> "an example" (exactly one example is
# shown), and "###New Instruction" -> "### New Instruction" to match the
# spacing of the other "### Example1 ..." section headers.
INSTRUCTION_ONE_SHOT_PROMPT = """\
I want you to act as a SQL terminal in front of an example database. \
You need only to return the sql command to me. \
First, I will show you an example of an instruction followed by the correct SQL response. \
Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
\n### Example1 Instruction:
The database contains tables such as employee, salary, and position. \
Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
Table position has columns such as position_id, title, and department. position_id is the primary key. \
The employee_id of salary is the foreign key of employee_id of employee. \
The position_id of employee is the foreign key of position_id of position.\
\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
\n### New Instruction:\n{}\n"""

# EXAMPLES =[EXAMPLE1, EXAMPLE1]

# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
# \n###New Instruction:\n{}\n"

### test--------------------


# METHODS = ["full", "freeze", "lora"]

# STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
Expand Down
9 changes: 9 additions & 0 deletions dbgpt_hub/data/dataset_info.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
"history": "history"
}
},
"example_text2sql_train_one_shot": {
"file_name": "example_text2sql_train_one_shot.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"example_rm_train": {
"file_name": "oaast_rm_zh.json",
"columns": {
Expand Down
26 changes: 23 additions & 3 deletions dbgpt_hub/data_process/sql_data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
DATA_PATH,
INPUT_PROMPT,
INSTRUCTION_PROMPT,
INSTRUCTION_ONE_SHOT_PROMPT,
)


class ProcessSqlData:
def __init__(self, train_file=None, dev_file=None) -> None:
def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
    """Hold output paths and the prompting mode for SFT data generation.

    Args:
        train_file: destination path for the generated training json.
        dev_file: destination path for the generated dev json.
        num_shot: 0 selects the zero-shot INSTRUCTION_PROMPT; 1 selects
            INSTRUCTION_ONE_SHOT_PROMPT (checked in decode_json_file).
            Only 0 and 1 are distinguished.
    """
    self.train_file = train_file
    self.dev_file = dev_file
    self.num_shot = num_shot

def decode_json_file(
self, data_file_list, table_file, db_id_name, is_multiple_turn=False
Expand Down Expand Up @@ -87,14 +89,18 @@ def decode_json_file(
db_dict[item["db_id"]] = source

res = []
base_instruction = INSTRUCTION_PROMPT
if self.num_shot == 1:
base_instruction = INSTRUCTION_ONE_SHOT_PROMPT

for data in tqdm(datas):
if data[db_id_name] in db_dict.keys():
if is_multiple_turn: # 多轮
history = []
for interaction in data["interaction"]:
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(
"instruction": base_instruction.format(
db_dict[data[db_id_name]]
),
"input": INPUT_PROMPT.format(interaction["utterance"]),
Expand All @@ -111,7 +117,7 @@ def decode_json_file(
else: # 单轮
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(
"instruction": base_instruction.format(
db_dict[data[db_id_name]]
),
"input": INPUT_PROMPT.format(data["question"]),
Expand Down Expand Up @@ -167,3 +173,17 @@ def create_sft_raw_data(self):
train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
)
precess.create_sft_raw_data()

# One-shot pass: regenerate the SFT data using the one-shot prompt template.
# Same pipeline as the zero-shot run above, only the output filenames and
# num_shot differ.
one_shot_train_path = os.path.join(
    DATA_PATH, "example_text2sql_train_one_shot.json"
)
one_shot_dev_path = os.path.join(
    DATA_PATH, "example_text2sql_dev_one_shot.json"
)
one_shot_processor = ProcessSqlData(
    train_file=one_shot_train_path,
    dev_file=one_shot_dev_path,
    num_shot=1,
)
one_shot_processor.create_sft_raw_data()
19 changes: 16 additions & 3 deletions dbgpt_hub/scripts/train_sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,32 @@ train_log="dbgpt_hub/output/logs/train_sft_test_${current_date}.log"
start_time=$(date +%s)
echo " Train Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${train_log}

# # zero-shot
# num_shot=0

# one-shot train
num_shot=1

# Pick the dataset matching the prompting mode: the *_one_shot dataset was
# generated with INSTRUCTION_ONE_SHOT_PROMPT by sql_data_process.py.
dataset="example_text2sql_train"
if [ "$num_shot" -eq 1 ]; then
    dataset="example_text2sql_train_one_shot"
fi
# Placeholder: replace with the local path of your downloaded
# CodeLlama-13b-Instruct-hf checkpoint before running.
model_name_or_path="Your_download_CodeLlama-13b-Instruct-hf_path"
output_dir="dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora"

# the default param set could be run in a server with one a100(40G) gpu, if your server not support the set,you can set smaller param such as lora_rank and use qlora with quant 4 eg...
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub/train/sft_train.py \
--model_name_or_path Your_download_CodeLlama-13b-Instruct-hf_path \
--model_name_or_path $model_name_or_path \
--do_train \
--dataset example_text2sql_train \
--dataset $dataset \
--max_source_length 2048 \
--max_target_length 512 \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--template llama2 \
--lora_rank 64 \
--lora_alpha 32 \
--output_dir dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora \
--output_dir $output_dir \
--overwrite_cache \
--overwrite_output_dir \
--per_device_train_batch_size 1 \
Expand Down

0 comments on commit c066049

Please sign in to comment.