diff --git a/dbgpt_hub/configs/config.py b/dbgpt_hub/configs/config.py index b3f4853..cf267c8 100644 --- a/dbgpt_hub/configs/config.py +++ b/dbgpt_hub/configs/config.py @@ -44,6 +44,7 @@ EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"} # text2sql dataset information for processing sql data +# TODO: BIRD \ WiKiSQL \ ... SQL_DATA_INFO = [ { "data_source": "spider", @@ -53,6 +54,33 @@ "db_id_name": "db_id", "is_multiple_turn": False, } + , + { + "data_source": "chase", + "train_file": ["Chase/chase_train.json"], + "dev_file": ["Chase/chase_dev.json"], + "tables_file": "Chase/chase_tables.json", + "db_id_name": "database_id", + "is_multiple_turn": True, + } + , + { + "data_source": "cosql_dataset", + "train_file": ["sql_state_tracking/cosql_train.json"], + "dev_file": ["sql_state_tracking/cosql_dev.json"], + "tables_file": "tables.json", + "db_id_name": "database_id", + "is_multiple_turn": True, + } + , + { + "data_source": "sparc", + "train_file": ["train.json"], + "dev_file": ["dev.json"], + "tables_file": "tables.json", + "db_id_name": "database_id", + "is_multiple_turn": True, + } ] INSTRUCTION_PROMPT = """\ I want you to act as a SQL terminal in front of an example database, \ diff --git a/dbgpt_hub/data_process/sql_data_process.py b/dbgpt_hub/data_process/sql_data_process.py index 8b0286e..25a3b5f 100644 --- a/dbgpt_hub/data_process/sql_data_process.py +++ b/dbgpt_hub/data_process/sql_data_process.py @@ -17,15 +17,15 @@ class ProcessSqlData: - def __init__(self) -> None: - pass + def __init__(self, train_file=None, dev_file=None) -> None: + self.train_file = train_file + self.dev_file = dev_file - def decode_json_file(self, data_file_list, table_file, out_file): + def decode_json_file(self, data_file_list, table_file, db_id_name, is_multiple_turn=False): """ TO DO: 1.将相关prompt放入config中 2.将不同数据来源的字段信息放入config中 - 3.支持多轮对话数据集 """ if table_file.endswith(".jsonl"): @@ -84,49 +84,73 @@ def decode_json_file(self, data_file_list, table_file, out_file): db_dict[item["db_id"]] = source - # 单论对话 res = [] for data in tqdm(datas): - if data["db_id"] in db_dict.keys(): - input = { - "db_id": data["db_id"], - "instruction": INSTRUCTION_PROMPT.format(db_dict[data["db_id"]]), - "input": INPUT_PROMPT.format(data["question"]), - "output": data["query"], - "history": [], - } - res.append(input) - - with open(out_file, "w", encoding="utf-8") as s: - json.dump(res, s, indent=4, ensure_ascii=False) + if data[db_id_name] in db_dict.keys(): + if is_multiple_turn: #多轮 + history = [] + for interaction in data["interaction"]: + input = { + "db_id": data[db_id_name], + "instruction": INSTRUCTION_PROMPT.format(db_dict[data[db_id_name]]), + "input": INPUT_PROMPT.format(interaction["utterance"]), + "output": interaction["query"], + "history": history, + } + res.append(input) + history.append((INPUT_PROMPT.format(interaction["utterance"]), interaction["query"])) + else: # 单轮 + input = { + "db_id": data[db_id_name], + "instruction": INSTRUCTION_PROMPT.format(db_dict[data[db_id_name]]), + "input": INPUT_PROMPT.format(data["question"]), + "output": data["query"], + "history": [], + } + res.append(input) + return res def create_sft_raw_data(self): + train_data = [] + dev_data = [] for data_info in SQL_DATA_INFO: train_data_file_list = [ os.path.join(DATA_PATH, data_info["data_source"], file) for file in data_info["train_file"] ] - self.decode_json_file( - data_file_list=train_data_file_list, - table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["tables_file"] - ), - out_file=os.path.join(DATA_PATH, "example_text2sql_train.json"), + train_data.extend( + self.decode_json_file( + data_file_list=train_data_file_list, + table_file=os.path.join( + DATA_PATH, data_info["data_source"], data_info["tables_file"] + ), + db_id_name=data_info["db_id_name"], + is_multiple_turn=data_info['is_multiple_turn'] + ) ) dev_data_file_list = [ os.path.join(DATA_PATH, data_info["data_source"], file) for file in data_info["dev_file"] ] - self.decode_json_file( - data_file_list=dev_data_file_list, - table_file=os.path.join( - DATA_PATH, data_info["data_source"], data_info["tables_file"] - ), - out_file=os.path.join(DATA_PATH, "example_text2sql_dev.json"), + dev_data.extend( + self.decode_json_file( + data_file_list=dev_data_file_list, + table_file=os.path.join( + DATA_PATH, data_info["data_source"], data_info["tables_file"] + ), + db_id_name=data_info["db_id_name"], + is_multiple_turn=data_info['is_multiple_turn'] + ) ) + with open(self.train_file, "w", encoding="utf-8") as s: + json.dump(train_data, s, indent=4, ensure_ascii=False) + with open(self.dev_file, "w", encoding="utf-8") as s: + json.dump(dev_data, s, indent=4, ensure_ascii=False) if __name__ == "__main__": - precess = ProcessSqlData() + all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train.json") + all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev.json") + precess = ProcessSqlData(train_file=all_in_one_train_file, dev_file=all_in_one_dev_file) precess.create_sft_raw_data() diff --git a/requirements.txt b/requirements.txt index a8504c8..b21b37f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ nltk jsonlines pymysql pyyaml -json \ No newline at end of file +# json \ No newline at end of file