Merge pull request #123 from oushu1zhangxiangxuan1/debug
#122: support multi-turn datasets (chase/sparc/cosql), and merge all data together
wangzaistone authored Nov 8, 2023
2 parents 610f7e7 + d524303 commit 01f982a
Showing 3 changed files with 83 additions and 31 deletions.
28 changes: 28 additions & 0 deletions dbgpt_hub/configs/config.py
@@ -44,6 +44,7 @@
EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

# text2sql dataset information for processing sql data
# TODO: BIRD \ WiKiSQL \ ...
SQL_DATA_INFO = [
{
"data_source": "spider",
@@ -53,6 +54,33 @@
"db_id_name": "db_id",
"is_multiple_turn": False,
}
,
{
"data_source": "chase",
"train_file": ["Chase/chase_train.json"],
"dev_file": ["Chase/chase_dev.json"],
"tables_file": "Chase/chase_tables.json",
"db_id_name": "database_id",
"is_multiple_turn": True,
}
,
{
"data_source": "cosql_dataset",
"train_file": ["sql_state_tracking/cosql_train.json"],
"dev_file": ["sql_state_tracking/cosql_dev.json"],
"tables_file": "tables.json",
"db_id_name": "database_id",
"is_multiple_turn": True,
}
,
{
"data_source": "sparc",
"train_file": ["train.json"],
"dev_file": ["dev.json"],
"tables_file": "tables.json",
"db_id_name": "database_id",
"is_multiple_turn": True,
}
]
INSTRUCTION_PROMPT = """\
I want you to act as a SQL terminal in front of an example database, \
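For comparison, the TODO above mentions BIRD and WikiSQL as future sources. Registering another single-turn dataset would only need one more entry in SQL_DATA_INFO with the same keys; the sketch below is hypothetical (the file names, paths, and db_id field are assumptions, not part of this commit):

    # Hypothetical entry (not in this commit): a new single-turn dataset such
    # as BIRD. File names and db_id_name are assumed placeholders.
    {
        "data_source": "bird",
        "train_file": ["train/train.json"],
        "dev_file": ["dev/dev.json"],
        "tables_file": "dev/dev_tables.json",
        "db_id_name": "db_id",
        "is_multiple_turn": False,
    },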
84 changes: 54 additions & 30 deletions dbgpt_hub/data_process/sql_data_process.py
@@ -17,15 +17,15 @@


class ProcessSqlData:
def __init__(self) -> None:
pass
def __init__(self, train_file=None, dev_file=None) -> None:
self.train_file = train_file
self.dev_file = dev_file

def decode_json_file(self, data_file_list, table_file, out_file):
def decode_json_file(self, data_file_list, table_file, db_id_name, is_multiple_turn=False):
"""
TODO:
1. Move the related prompts into config
2. Move the field information for the different data sources into config
3. Support multi-turn dialogue datasets
"""

if table_file.endswith(".jsonl"):
@@ -84,49 +84,73 @@ def decode_json_file(self, data_file_list, table_file, out_file):

db_dict[item["db_id"]] = source

# single-turn dialogue
res = []
for data in tqdm(datas):
if data["db_id"] in db_dict.keys():
input = {
"db_id": data["db_id"],
"instruction": INSTRUCTION_PROMPT.format(db_dict[data["db_id"]]),
"input": INPUT_PROMPT.format(data["question"]),
"output": data["query"],
"history": [],
}
res.append(input)

with open(out_file, "w", encoding="utf-8") as s:
json.dump(res, s, indent=4, ensure_ascii=False)
if data[db_id_name] in db_dict.keys():
if is_multiple_turn:  # multi-turn
history = []
for interaction in data["interaction"]:
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(db_dict[data[db_id_name]]),
"input": INPUT_PROMPT.format(interaction["utterance"]),
"output": interaction["query"],
"history": history,
}
res.append(input)
history.append((INPUT_PROMPT.format(interaction["utterance"]), interaction["query"]))
else:  # single-turn
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(db_dict[data[db_id_name]]),
"input": INPUT_PROMPT.format(data["question"]),
"output": data["query"],
"history": [],
}
res.append(input)
return res

def create_sft_raw_data(self):
train_data = []
dev_data = []
for data_info in SQL_DATA_INFO:
train_data_file_list = [
os.path.join(DATA_PATH, data_info["data_source"], file)
for file in data_info["train_file"]
]
self.decode_json_file(
data_file_list=train_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["tables_file"]
),
out_file=os.path.join(DATA_PATH, "example_text2sql_train.json"),
train_data.extend(
self.decode_json_file(
data_file_list=train_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["tables_file"]
),
db_id_name=data_info["db_id_name"],
is_multiple_turn=data_info['is_multiple_turn']
)
)

dev_data_file_list = [
os.path.join(DATA_PATH, data_info["data_source"], file)
for file in data_info["dev_file"]
]
self.decode_json_file(
data_file_list=dev_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["tables_file"]
),
out_file=os.path.join(DATA_PATH, "example_text2sql_dev.json"),
dev_data.extend(
self.decode_json_file(
data_file_list=dev_data_file_list,
table_file=os.path.join(
DATA_PATH, data_info["data_source"], data_info["tables_file"]
),
db_id_name=data_info["db_id_name"],
is_multiple_turn=data_info['is_multiple_turn']
)
)
with open(self.train_file, "w", encoding="utf-8") as s:
json.dump(train_data, s, indent=4, ensure_ascii=False)
with open(self.dev_file, "w", encoding="utf-8") as s:
json.dump(dev_data, s, indent=4, ensure_ascii=False)


if __name__ == "__main__":
precess = ProcessSqlData()
all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train.json")
all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev.json")
precess = ProcessSqlData(train_file=all_in_one_train_file, dev_file=all_in_one_dev_file)
precess.create_sft_raw_data()
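To illustrate the new multi-turn branch of decode_json_file, here is a minimal, self-contained sketch of how one CoSQL/SParC-style "interaction" is flattened into per-turn SFT records, each carrying the earlier turns as history. The toy data and simplified prompt strings are assumptions for illustration only; also note the sketch snapshots the history with list(history), whereas the diff above appends the shared history list directly.

    # Minimal sketch (toy data, simplified prompts) of the multi-turn flattening.
    INSTRUCTION = "Database schema: {}\n"   # stand-in for INSTRUCTION_PROMPT
    INPUT = "Question: {}"                  # stand-in for INPUT_PROMPT

    data = {
        "database_id": "concert_singer",
        "interaction": [
            {"utterance": "How many singers do we have?",
             "query": "SELECT count(*) FROM singer"},
            {"utterance": "Only the ones from France.",
             "query": "SELECT count(*) FROM singer WHERE country = 'France'"},
        ],
    }

    res, history = [], []
    for turn in data["interaction"]:
        res.append({
            "db_id": data["database_id"],
            "instruction": INSTRUCTION.format("singer(singer_id, name, country)"),
            "input": INPUT.format(turn["utterance"]),
            "output": turn["query"],
            "history": list(history),  # only the (input, query) pairs seen so far
        })
        history.append((INPUT.format(turn["utterance"]), turn["query"]))

    # res[0]["history"] == []
    # res[1]["history"] == [("Question: How many singers do we have?",
    #                        "SELECT count(*) FROM singer")]

With the __main__ block above, running the script (e.g. python -m dbgpt_hub.data_process.sql_data_process, depending on how the package is laid out) writes the merged records from all configured sources to example_text2sql_train.json and example_text2sql_dev.json under DATA_PATH.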
2 changes: 1 addition & 1 deletion requirements.txt
@@ -29,4 +29,4 @@ nltk
jsonlines
pymysql
pyyaml
json
# json
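The json entry is commented out because json is part of the Python standard library, so there is nothing to install from PyPI under that name; a quick sanity check:

    # json ships with the interpreter; no requirements.txt entry is needed.
    import json
    print(json.dumps({"ok": True}))  # prints {"ok": true} without any install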

