Skip to content

Commit

Permalink
Merge branch 'feature/bird_zf' of https://github.com/moutozf/DB-GPT-Hub
Browse files Browse the repository at this point in the history
… into feature/bird_zf

and change dataset to spider
# Conflicts:
#	dbgpt_hub/data_process/sql_data_process.py
#	dbgpt_hub/eval/evaluation_bird.py
  • Loading branch information
moutozf committed Jan 1, 2024
2 parents 77f9ead + 44d43fd commit 4d17d97
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 76 deletions.
Binary file modified assets/wechat.JPG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
49 changes: 37 additions & 12 deletions dbgpt_hub/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@
# TODO: BIRD \ WiKiSQL \ ...
SQL_DATA_INFO = [
{
"data_source": "bird",
"train_file": ["train/train.json"],
"dev_file": ["dev/dev.json"],
"train_tables_file": "train/train_tables.json",
"dev_tables_file": "dev/dev_tables.json",
"data_source": "spider",
"train_file": ["train_spider.json", "train_others.json"],
"dev_file": ["dev.json"],
"train_tables_file": "tables.json",
"dev_tables_file": "tables.json",
"db_id_name": "db_id",
"output_name": "SQL",
"output_name": "query",
"is_multiple_turn": False,
}
# {
# "data_source": "spider",
# "train_file": ["train_spider.json", "train_others.json"],
# "dev_file": ["dev.json"],
# "train_tables_file": "tables.json",
# "dev_tables_file": "tables.json",
# "data_source": "bird",
# "train_file": ["train/train.json"],
# "dev_file": ["dev/dev.json"],
# "train_tables_file": "train/train_tables.json",
# "dev_tables_file": "dev/dev_tables.json",
# "db_id_name": "db_id",
# "output_name": "query",
# "output_name": "SQL",
# "is_multiple_turn": False,
# }
# ,
Expand Down Expand Up @@ -101,6 +101,31 @@
##Instruction:\n{}\n"""
INPUT_PROMPT = "###Input:\n{}\n\n###Response:"

# One-shot variant of the Text-to-SQL instruction prompt: embeds a single
# worked example (a toy employee/salary/position schema, an example question,
# and its correct SQL answer) ahead of the real task. The trailing "{}" is
# filled with the actual database schema description via .format() (selected
# in place of INSTRUCTION_PROMPT when num_shot == 1 in ProcessSqlData).
INSTRUCTION_ONE_SHOT_PROMPT = """\
I want you to act as a SQL terminal in front of an example database. \
You need only to return the sql command to me. \
First, I will show you few examples of an instruction followed by the correct SQL response. \
Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
\n### Example1 Instruction:
The database contains tables such as employee, salary, and position. \
Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
Table position has columns such as position_id, title, and department. position_id is the primary key. \
The employee_id of salary is the foreign key of employee_id of employee. \
The position_id of employee is the foreign key of position_id of position.\
\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
\n###New Instruction:\n{}\n"""

# EXAMPLES =[EXAMPLE1, EXAMPLE1]

# EXAMPLE1 = "\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
# \n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
# \n###New Instruction:\n{}\n"

### test--------------------


# METHODS = ["full", "freeze", "lora"]

# STAGES = ["SFT", "Reward Modeling", "PPO", "DPO", "Pre-Training"]
Expand Down
9 changes: 9 additions & 0 deletions dbgpt_hub/data/dataset_info.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
"history": "history"
}
},
"example_text2sql_train_one_shot": {
"file_name": "example_text2sql_train_one_shot.json",
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"history": "history"
}
},
"example_rm_train": {
"file_name": "oaast_rm_zh.json",
"columns": {
Expand Down
55 changes: 42 additions & 13 deletions dbgpt_hub/data_process/sql_data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@
DATA_PATH,
INPUT_PROMPT,
INSTRUCTION_PROMPT,
INSTRUCTION_ONE_SHOT_PROMPT,
)


class ProcessSqlData:
def __init__(self, train_file=None, dev_file=None) -> None:
def __init__(self, train_file=None, dev_file=None, num_shot=0) -> None:
self.train_file = train_file
self.dev_file = dev_file
self.num_shot = num_shot

def decode_json_file(
self, data_file_list, table_file, db_id_name, output_name, is_multiple_turn=False
self,
data_file_list,
table_file,
db_id_name,
output_name,
is_multiple_turn=False,
):
"""
TO DO:
Expand Down Expand Up @@ -66,9 +73,9 @@ def decode_json_file(
if type(primary_key[j]) == int:
if coloumns[primary_key[j] - 1][0] == i:
source += (
coloumns[primary_key[j] - 1][1]
+ " is the primary key."
+ "\n"
coloumns[primary_key[j] - 1][1]
+ " is the primary key."
+ "\n"
)
# combination primary key
elif type(primary_key[j]) == list:
Expand All @@ -78,10 +85,10 @@ def decode_json_file(
if coloumns[primary_key[j][k] - 1][0] == i:
keys.append(coloumns[primary_key[j][k] - 1][1])
source += (
combine_p +
", ".join(keys)
+ ") are the primary key."
+ "\n"
combine_p
+ ", ".join(keys)
+ ") are the primary key."
+ "\n"
)
else:
print("not support type", type(primary_key[j]))
Expand All @@ -104,14 +111,18 @@ def decode_json_file(
db_dict[item["db_id"]] = source

res = []
base_instruction = INSTRUCTION_PROMPT
if self.num_shot == 1:
base_instruction = INSTRUCTION_ONE_SHOT_PROMPT

for data in tqdm(datas):
if data[db_id_name] in db_dict.keys():
if is_multiple_turn: # 多轮
history = []
for interaction in data["interaction"]:
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(
"instruction": base_instruction.format(
db_dict[data[db_id_name]]
),
"input": INPUT_PROMPT.format(interaction["utterance"]),
Expand All @@ -128,7 +139,7 @@ def decode_json_file(
else: # 单轮
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(
"instruction": base_instruction.format(
db_dict[data[db_id_name]]
),
"input": INPUT_PROMPT.format(data["question"]),
Expand All @@ -150,7 +161,9 @@ def create_sft_raw_data(self):
self.decode_json_file(
data_file_list=train_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["train_tables_file"]
DATA_PATH,
data_info["data_source"],
data_info["train_tables_file"],
),
db_id_name=data_info["db_id_name"],
output_name=data_info["output_name"],
Expand All @@ -166,7 +179,9 @@ def create_sft_raw_data(self):
self.decode_json_file(
data_file_list=dev_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["dev_tables_file"]
DATA_PATH,
data_info["data_source"],
data_info["dev_tables_file"],
),
db_id_name=data_info["db_id_name"],
output_name=data_info["output_name"],
Expand All @@ -186,3 +201,17 @@ def create_sft_raw_data(self):
train_file=all_in_one_train_file, dev_file=all_in_one_dev_file
)
precess.create_sft_raw_data()

# one-shot
one_shot_all_in_one_train_file = os.path.join(
DATA_PATH, "example_text2sql_train_one_shot.json"
)
one_shot_all_in_one_dev_file = os.path.join(
DATA_PATH, "example_text2sql_dev_one_shot.json"
)
one_shot_precess = ProcessSqlData(
train_file=one_shot_all_in_one_train_file,
dev_file=one_shot_all_in_one_dev_file,
num_shot=1,
)
one_shot_precess.create_sft_raw_data()
Loading

0 comments on commit 4d17d97

Please sign in to comment.