From e5ce64c50080cc45d3d416ea11614be9ec583caa Mon Sep 17 00:00:00 2001
From: qidanrui
Date: Mon, 25 Dec 2023 08:17:48 +0000
Subject: [PATCH 1/3] update bitsandbytes

---
 poetry.lock    | 8 ++++----
 pyproject.toml | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index d4528d6..a9607d0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -262,13 +262,13 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte

 [[package]]
 name = "bitsandbytes"
-version = "0.41.2"
+version = "0.41.3.post2"
 description = "k-bit optimizers and matrix multiplication routines."
 optional = false
 python-versions = "*"
 files = [
-    {file = "bitsandbytes-0.41.2-py3-none-any.whl", hash = "sha256:5a2280761dc11c7a23a1be948cfd6a849c2e718012ee34316b979eb6c5634de2"},
-    {file = "bitsandbytes-0.41.2.tar.gz", hash = "sha256:787c14b63cc559e1b344f683497a9353ac2e256a3fe89972f960e7c428d5cce7"},
+    {file = "bitsandbytes-0.41.3.post2-py3-none-any.whl", hash = "sha256:ceb301a3d4e6bf52bdad8d09f3064ac194bdfdeae535994c0315bd2ef7639cca"},
+    {file = "bitsandbytes-0.41.3.post2.tar.gz", hash = "sha256:7d25a51fb3b74b58e569473f8b70a5239124c0593dc053479c41cf2cd6730502"},
 ]

 [[package]]
@@ -4285,4 +4285,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "31cd031e60d15eb9bdca0afc4ce0fb8261ee48990035089d4f29d69bc23b702a"
+content-hash = "131fd0b341707b7ea14e4fef407141e550b1692e54b00a33d107123373a218de"
diff --git a/pyproject.toml b/pyproject.toml
index b6d6018..44fd75f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,9 +57,9 @@ nvidia-nccl-cu12 = "2.18.1"
 nvidia-nvtx-cu12 = "12.1.105"
 triton = "2.1.0"
 nvidia-nvjitlink-cu12 = "^12.3.52"
-bitsandbytes = "0.41.2"
 prettytable = "^3.9.0"
 docopt = "^0.6.2"
+bitsandbytes = "0.41.3.post2"

 [build-system]
 requires = ["poetry-core"]

From e60772a7a7251abbe4669c7f9d37859e6789820d Mon Sep 17 00:00:00 2001
From: qidanrui
Date: Thu, 28 Dec 2023 08:39:39 +0000
Subject: [PATCH 2/3] add gpt synthetic dataset generator

---
 dbgpt_hub/data_generator/__init__.py          |   8 ++
 dbgpt_hub/data_generator/gpt_generator.py     | 125 ++++++++++++++++++
 dbgpt_hub/data_generator/gpt_generator_api.py |  38 ++++++
 dbgpt_hub/data_generator/llm_generator.py     |  23 ++++
 dbgpt_hub/data_generator/utils.py             |  35 +++++
 poetry.lock                                   |  36 ++++-
 pyproject.toml                                |   1 +
 7 files changed, 265 insertions(+), 1 deletion(-)
 create mode 100644 dbgpt_hub/data_generator/__init__.py
 create mode 100644 dbgpt_hub/data_generator/gpt_generator.py
 create mode 100644 dbgpt_hub/data_generator/gpt_generator_api.py
 create mode 100644 dbgpt_hub/data_generator/llm_generator.py
 create mode 100644 dbgpt_hub/data_generator/utils.py

diff --git a/dbgpt_hub/data_generator/__init__.py b/dbgpt_hub/data_generator/__init__.py
new file mode 100644
index 0000000..dee6f5c
--- /dev/null
+++ b/dbgpt_hub/data_generator/__init__.py
@@ -0,0 +1,8 @@
+"""
+dbgpt_hub.data_generator
+========================
+"""
+
+from .gpt_generator_api import generate_dataset_with_gpt
+
+__all__ = ["generate_dataset_with_gpt"]
diff --git a/dbgpt_hub/data_generator/gpt_generator.py b/dbgpt_hub/data_generator/gpt_generator.py
new file mode 100644
index 0000000..82e5c59
--- /dev/null
+++ b/dbgpt_hub/data_generator/gpt_generator.py
@@ -0,0 +1,125 @@
+from openai import OpenAI
+
+import os
+import json
+
+from tqdm import tqdm
+
+from llm_generator import LLMGenerator
+from utils import COT_PROMPT, FEW_SHOTS_EXAMPLE
+
+class GPTGenerator(LLMGenerator):
+
+    def __init__(
+        self,
+        model: str =
"gpt-3.5-turbo-16k", + model_temperature: int = 0, + max_tokens: int = 2048, + prompt: str = "", + num_text2sql_pair_each_db: int = 10, + table_file_path: str = "", + output_path: str = "" + ): + if len(table_file_path) > 0: + self.table_file_path = table_file_path + else: + self.table_file_path = "../data/spider/tables.json" + + if len(output_path) > 0: + self.output_path = output_path + else: + self.output_path = "../data/spider/synthetic_data_with_gpt.json" + + if len(prompt) > 0: + self.prompt = prompt + else: + self.prompt = COT_PROMPT + self.model = model + self.model_temperature = model_temperature + self.max_tokens = max_tokens + self.num_text2sql_pair_each_db = num_text2sql_pair_each_db + + self.synthetic_dataset = [] + + def generate_synthetic_dataset(self): + """Function for generating synthetic dataset. + By default, we generate Spider-like synthetic dataset. + """ + schema = "" + synthetic_dataset = [] + + tables = json.load(open(self.table_file_path)) + db_num = len(tables) + easy_count = int(self.num_text2sql_pair_each_db / db_num) + medium_count = int(self.num_text2sql_pair_each_db / db_num) + hard_count = self.num_text2sql_pair_each_db - easy_count - medium_count + + db_dict = {} + for item in tqdm(tables[:]): + tables = item["table_names_original"] + coloumns = item["column_names_original"][1:] + primary_key = item["primary_keys"] + foreign_keys = item["foreign_keys"] + schema = item["db_id"] + " database contains tables such as " + ", ".join(tables) + ". " + for i, name in enumerate(tables): + data = [coloumn[1] for coloumn in coloumns if coloumn[0] == i] + schema += "Table " + name + " has columns such as " + ", ".join(data) + ". " + + # get primary key info + for j in range(len(primary_key)): + if coloumns[primary_key[j]-1][0] == i: + schema += coloumns[primary_key[j]-1][1] + " is the primary key." 
+ "\n" + + # get foreign key info + for key in foreign_keys: + schema += "The " + coloumns[key[0]-1][1] + " of " + tables[coloumns[key[0]-1][0]] + \ + " is the foreign key of " + coloumns[key[1]-1][1] + \ + " of " + tables[coloumns[key[1]-1][0]] + ".\n" + + db_dict[item["db_id"]] = schema + + try: + # Single generated data for one DB + for k in range(self.num_text2sql_pair_each_db): + text2sql_pair = self._chat_llm( + self.prompt.format( + easy_count=easy_count, + medium_count=medium_count, + hard_count=hard_count, + schema=schema, + few_shots_example=FEW_SHOTS_EXAMPLE) + ) + text2sql_pair = eval(text2sql_pair) + synthetic_dataset += text2sql_pair + except: + continue + + self.synthetic_dataset = synthetic_dataset + self._writeout_dataset() + + def _chat_llm(self, prompt): + client = OpenAI( + api_key=os.environ['OPENAI_API_KEY'], + ) + + completion = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": prompt, + } + ], + model=self.model, + temperature=self.model_temperature, + max_tokens=self.max_tokens, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + return(completion.choices[0].message.content) + + def _writeout_dataset(self): + with open(self.output_path, "w", encoding="utf-8") as s: + json.dump(self.synthetic_dataset, s, indent=4, ensure_ascii=False) + + diff --git a/dbgpt_hub/data_generator/gpt_generator_api.py b/dbgpt_hub/data_generator/gpt_generator_api.py new file mode 100644 index 0000000..5cb9f4e --- /dev/null +++ b/dbgpt_hub/data_generator/gpt_generator_api.py @@ -0,0 +1,38 @@ + +import os +import sys + +from typing import Optional, Dict, Any +from gpt_generator import GPTGenerator +from utils import COT_PROMPT + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + +def generate_dataset_with_gpt( + args: Optional[Dict[str, Any]] = None +): + # Default Arguments + if args is None: + args = { + "model": "gpt-3.5-turbo-16k", + "prompt": COT_PROMPT, + "num_text2sql_pair_each_db": 1, + "table_file_path": os.path.join(ROOT_PATH, "dbgpt_hub/data/spider/tables.json"), + "output_path": os.path.join(ROOT_PATH, "dbgpt_hub/data/spider/synthetic_data_with_gpt.json") + } + else: + args = args + + # Run GPT Generator + gpt_generator = GPTGenerator( + model=args["model"], + prompt=args["prompt"], + num_text2sql_pair_each_db=args["num_text2sql_pair_each_db"], + table_file_path=args["table_file_path"], + output_path=args["output_path"] + ) + gpt_generator.generate_synthetic_dataset() + +if __name__ == "__main__": + generate_dataset_with_gpt() diff --git a/dbgpt_hub/data_generator/llm_generator.py b/dbgpt_hub/data_generator/llm_generator.py new file mode 100644 index 0000000..eb18a74 --- /dev/null +++ b/dbgpt_hub/data_generator/llm_generator.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod + +from typing import Any, Dict, List, Optional, Tuple, Union + +class LLMGenerator(ABC): + """An interface for large language model data generator. + A LLM data generator can accept prompts and generate synthetic Text2SQL dataset. 
+    """
+
+    @abstractmethod
+    def generate_synthetic_dataset(self):
+        """Function for generating synthetic dataset"""
+        pass
+
+    @abstractmethod
+    def _chat_llm(self):
+        """Function for interacting with LLMs"""
+        pass
+
+    @abstractmethod
+    def _writeout_dataset(self):
+        """Function for writing out generated dataset"""
+        pass
diff --git a/dbgpt_hub/data_generator/utils.py b/dbgpt_hub/data_generator/utils.py
new file mode 100644
index 0000000..c4365a8
--- /dev/null
+++ b/dbgpt_hub/data_generator/utils.py
@@ -0,0 +1,35 @@
+COT_PROMPT = """
+Please use the following database information to generate natural language questions of different difficulty levels, together with their corresponding SQL queries.
+There are three different difficulty levels: Easy, Medium, and Hard.
+When generating natural language questions and their corresponding SQL queries, you should consider different SQL operators such as WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, JOIN, INTERSECT, EXCEPT, UNION, NOT IN, OR, AND, EXISTS, LIKE as well as nested queries.
+Moreover, please make sure that each table in the database appears in at least one query.
+
+The three difficulty levels are defined as follows:
+Easy: Queries that require basic filtering or aggregation on a single table.
+Medium: Queries that encompass more complex filtering or aggregation and involve joining multiple tables.
+Hard: Queries that entail advanced filtering or aggregation, multiple joins, and the use of subqueries.
+
+Here is the basic information of the database: {schema}
+
+Based on the tables, columns, primary keys, foreign keys, and the difficulty levels defined above, generate {easy_count} Easy, {medium_count} Medium, and {hard_count} Hard natural language questions with their corresponding SQL queries.
+Provide your answer in JSON form. Reply with only the answer in JSON form and include no other commentary:
+RESPONSE FORMAT:
+{few_shots_example}
+
+The "db_id" in the above examples means the name of the database used. Do not append a "_database" suffix to it.
+""" + +FEW_SHOTS_EXAMPLE = """ +[ + { + 'db_id': 'music_2', + 'question': 'Who performed the song named "Le Pop"?', + 'query': 'SELECT T2.firstname, T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate=T2.id JOIN Songs AS T3 ON T3.SongId=T1.SongId WHERE T3.Title="Le Pop"' + }, + { + 'db_id': 'insurance_fnol', + 'question': 'Tell me the types of the policy used by the customer named "Dayana Robel".', + 'query': 'SELECT DISTINCT t3.policy_type_code FROM customers AS t1 JOIN customers_policies AS t2 ON t1.customer_id=t2.customer_id JOIN available_policies AS t3 ON t2.policy_id=t3.policy_id WHERE t1.customer_name="Dayana Robel"' + } +] +""" \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index a9607d0..d723ea9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -616,6 +616,17 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -1946,6 +1957,29 @@ files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "openai" +version = "1.6.1" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.6.1-py3-none-any.whl", hash = "sha256:bc9f774838d67ac29fb24cdeb2d58faf57de8b311085dcd1348f7aa02a96c7ee"}, + {file = "openai-1.6.1.tar.gz", hash = "sha256:d553ca9dbf9486b08e75b09e8671e4f638462aaadccfced632bf490fc3d75fa2"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "orjson" version = "3.9.10" @@ -4285,4 +4319,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "131fd0b341707b7ea14e4fef407141e550b1692e54b00a33d107123373a218de" +content-hash = "b973d23812596d562bf030c5029e0a548e6da44516de2b98cb5735a4ef271a09" diff --git a/pyproject.toml b/pyproject.toml index 44fd75f..c1e7312 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ nvidia-nvjitlink-cu12 = "^12.3.52" prettytable = "^3.9.0" docopt = "^0.6.2" bitsandbytes = "0.41.3.post2" +openai = "^1.6.1" [build-system] requires = ["poetry-core"] From 359195478571fc8e6126b32efe7567cc8b1bbbb6 Mon Sep 17 00:00:00 2001 From: qidanrui Date: Thu, 28 Dec 2023 08:54:45 +0000 Subject: [PATCH 3/3] reformat code with black --- dbgpt_hub/data_generator/gpt_generator.py | 250 ++++++++++-------- dbgpt_hub/data_generator/gpt_generator_api.py | 23 +- dbgpt_hub/data_generator/llm_generator.py | 1 + dbgpt_hub/data_generator/utils.py | 2 +- 4 files changed, 149 insertions(+), 127 deletions(-) diff --git a/dbgpt_hub/data_generator/gpt_generator.py b/dbgpt_hub/data_generator/gpt_generator.py index 82e5c59..3023791 100644 --- a/dbgpt_hub/data_generator/gpt_generator.py +++ b/dbgpt_hub/data_generator/gpt_generator.py @@ -5,121 +5,139 @@ from tqdm import tqdm -from llm_generator 
import LLMGenerator -from utils import COT_PROMPT, FEW_SHOTS_EXAMPLE +from .llm_generator import LLMGenerator +from .utils import COT_PROMPT, FEW_SHOTS_EXAMPLE -class GPTGenerator(LLMGenerator): - - def __init__( - self, - model: str = "gpt-3.5-turbo-16k", - model_temperature: int = 0, - max_tokens: int = 2048, - prompt: str = "", - num_text2sql_pair_each_db: int = 10, - table_file_path: str = "", - output_path: str = "" - ): - if len(table_file_path) > 0: - self.table_file_path = table_file_path - else: - self.table_file_path = "../data/spider/tables.json" - - if len(output_path) > 0: - self.output_path = output_path - else: - self.output_path = "../data/spider/synthetic_data_with_gpt.json" - - if len(prompt) > 0: - self.prompt = prompt - else: - self.prompt = COT_PROMPT - self.model = model - self.model_temperature = model_temperature - self.max_tokens = max_tokens - self.num_text2sql_pair_each_db = num_text2sql_pair_each_db - - self.synthetic_dataset = [] - - def generate_synthetic_dataset(self): - """Function for generating synthetic dataset. - By default, we generate Spider-like synthetic dataset. - """ - schema = "" - synthetic_dataset = [] - - tables = json.load(open(self.table_file_path)) - db_num = len(tables) - easy_count = int(self.num_text2sql_pair_each_db / db_num) - medium_count = int(self.num_text2sql_pair_each_db / db_num) - hard_count = self.num_text2sql_pair_each_db - easy_count - medium_count - - db_dict = {} - for item in tqdm(tables[:]): - tables = item["table_names_original"] - coloumns = item["column_names_original"][1:] - primary_key = item["primary_keys"] - foreign_keys = item["foreign_keys"] - schema = item["db_id"] + " database contains tables such as " + ", ".join(tables) + ". " - for i, name in enumerate(tables): - data = [coloumn[1] for coloumn in coloumns if coloumn[0] == i] - schema += "Table " + name + " has columns such as " + ", ".join(data) + ". " - - # get primary key info - for j in range(len(primary_key)): - if coloumns[primary_key[j]-1][0] == i: - schema += coloumns[primary_key[j]-1][1] + " is the primary key." 
+ "\n" - - # get foreign key info - for key in foreign_keys: - schema += "The " + coloumns[key[0]-1][1] + " of " + tables[coloumns[key[0]-1][0]] + \ - " is the foreign key of " + coloumns[key[1]-1][1] + \ - " of " + tables[coloumns[key[1]-1][0]] + ".\n" - - db_dict[item["db_id"]] = schema - - try: - # Single generated data for one DB - for k in range(self.num_text2sql_pair_each_db): - text2sql_pair = self._chat_llm( - self.prompt.format( - easy_count=easy_count, - medium_count=medium_count, - hard_count=hard_count, - schema=schema, - few_shots_example=FEW_SHOTS_EXAMPLE) - ) - text2sql_pair = eval(text2sql_pair) - synthetic_dataset += text2sql_pair - except: - continue - - self.synthetic_dataset = synthetic_dataset - self._writeout_dataset() - - def _chat_llm(self, prompt): - client = OpenAI( - api_key=os.environ['OPENAI_API_KEY'], - ) - - completion = client.chat.completions.create( - messages=[ - { - "role": "user", - "content": prompt, - } - ], - model=self.model, - temperature=self.model_temperature, - max_tokens=self.max_tokens, - top_p=1, - frequency_penalty=0, - presence_penalty=0, - ) - return(completion.choices[0].message.content) - - def _writeout_dataset(self): - with open(self.output_path, "w", encoding="utf-8") as s: - json.dump(self.synthetic_dataset, s, indent=4, ensure_ascii=False) - +class GPTGenerator(LLMGenerator): + def __init__( + self, + model: str = "gpt-3.5-turbo-16k", + model_temperature: int = 0, + max_tokens: int = 2048, + prompt: str = "", + num_text2sql_pair_each_db: int = 10, + table_file_path: str = "", + output_path: str = "", + ): + if len(table_file_path) > 0: + self.table_file_path = table_file_path + else: + self.table_file_path = "../data/spider/tables.json" + + if len(output_path) > 0: + self.output_path = output_path + else: + self.output_path = "../data/spider/synthetic_data_with_gpt.json" + + if len(prompt) > 0: + self.prompt = prompt + else: + self.prompt = COT_PROMPT + self.model = model + self.model_temperature = model_temperature + self.max_tokens = max_tokens + self.num_text2sql_pair_each_db = num_text2sql_pair_each_db + + self.synthetic_dataset = [] + + def generate_synthetic_dataset(self): + """Function for generating synthetic dataset. + By default, we generate Spider-like synthetic dataset. + """ + schema = "" + synthetic_dataset = [] + + tables = json.load(open(self.table_file_path)) + db_num = len(tables) + easy_count = int(self.num_text2sql_pair_each_db / db_num) + medium_count = int(self.num_text2sql_pair_each_db / db_num) + hard_count = self.num_text2sql_pair_each_db - easy_count - medium_count + + db_dict = {} + for item in tqdm(tables[:]): + tables = item["table_names_original"] + coloumns = item["column_names_original"][1:] + primary_key = item["primary_keys"] + foreign_keys = item["foreign_keys"] + schema = ( + item["db_id"] + + " database contains tables such as " + + ", ".join(tables) + + ". " + ) + for i, name in enumerate(tables): + data = [coloumn[1] for coloumn in coloumns if coloumn[0] == i] + schema += ( + "Table " + name + " has columns such as " + ", ".join(data) + ". " + ) + + # get primary key info + for j in range(len(primary_key)): + if coloumns[primary_key[j] - 1][0] == i: + schema += ( + coloumns[primary_key[j] - 1][1] + + " is the primary key." 
+ + "\n" + ) + + # get foreign key info + for key in foreign_keys: + schema += ( + "The " + + coloumns[key[0] - 1][1] + + " of " + + tables[coloumns[key[0] - 1][0]] + + " is the foreign key of " + + coloumns[key[1] - 1][1] + + " of " + + tables[coloumns[key[1] - 1][0]] + + ".\n" + ) + + db_dict[item["db_id"]] = schema + + try: + # Single generated data for one DB + for k in range(self.num_text2sql_pair_each_db): + text2sql_pair = self._chat_llm( + self.prompt.format( + easy_count=easy_count, + medium_count=medium_count, + hard_count=hard_count, + schema=schema, + few_shots_example=FEW_SHOTS_EXAMPLE, + ) + ) + text2sql_pair = eval(text2sql_pair) + synthetic_dataset += text2sql_pair + except: + continue + + self.synthetic_dataset = synthetic_dataset + self._writeout_dataset() + + def _chat_llm(self, prompt): + client = OpenAI( + api_key=os.environ["OPENAI_API_KEY"], + ) + + completion = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": prompt, + } + ], + model=self.model, + temperature=self.model_temperature, + max_tokens=self.max_tokens, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + ) + return completion.choices[0].message.content + + def _writeout_dataset(self): + with open(self.output_path, "w", encoding="utf-8") as s: + json.dump(self.synthetic_dataset, s, indent=4, ensure_ascii=False) diff --git a/dbgpt_hub/data_generator/gpt_generator_api.py b/dbgpt_hub/data_generator/gpt_generator_api.py index 5cb9f4e..d34200f 100644 --- a/dbgpt_hub/data_generator/gpt_generator_api.py +++ b/dbgpt_hub/data_generator/gpt_generator_api.py @@ -1,25 +1,27 @@ - import os import sys from typing import Optional, Dict, Any -from gpt_generator import GPTGenerator -from utils import COT_PROMPT +from .gpt_generator import GPTGenerator +from .utils import COT_PROMPT ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -def generate_dataset_with_gpt( - args: Optional[Dict[str, Any]] = None -): + +def generate_dataset_with_gpt(args: Optional[Dict[str, Any]] = None): # Default Arguments if args is None: args = { - "model": "gpt-3.5-turbo-16k", + "model": "gpt-3.5-turbo-16k", "prompt": COT_PROMPT, "num_text2sql_pair_each_db": 1, - "table_file_path": os.path.join(ROOT_PATH, "dbgpt_hub/data/spider/tables.json"), - "output_path": os.path.join(ROOT_PATH, "dbgpt_hub/data/spider/synthetic_data_with_gpt.json") + "table_file_path": os.path.join( + ROOT_PATH, "dbgpt_hub/data/spider/tables.json" + ), + "output_path": os.path.join( + ROOT_PATH, "dbgpt_hub/data/spider/synthetic_data_with_gpt.json" + ), } else: args = args @@ -30,9 +32,10 @@ def generate_dataset_with_gpt( prompt=args["prompt"], num_text2sql_pair_each_db=args["num_text2sql_pair_each_db"], table_file_path=args["table_file_path"], - output_path=args["output_path"] + output_path=args["output_path"], ) gpt_generator.generate_synthetic_dataset() + if __name__ == "__main__": generate_dataset_with_gpt() diff --git a/dbgpt_hub/data_generator/llm_generator.py b/dbgpt_hub/data_generator/llm_generator.py index eb18a74..783acd3 100644 --- a/dbgpt_hub/data_generator/llm_generator.py +++ b/dbgpt_hub/data_generator/llm_generator.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union + class LLMGenerator(ABC): """An interface for large language model data generator. A LLM data generator can accept prompts and generate synthetic Text2SQL dataset. 
diff --git a/dbgpt_hub/data_generator/utils.py b/dbgpt_hub/data_generator/utils.py index c4365a8..bfb6697 100644 --- a/dbgpt_hub/data_generator/utils.py +++ b/dbgpt_hub/data_generator/utils.py @@ -32,4 +32,4 @@ 'query': 'SELECT DISTINCT t3.policy_type_code FROM customers AS t1 JOIN customers_policies AS t2 ON t1.customer_id=t2.customer_id JOIN available_policies AS t3 ON t2.policy_id=t3.policy_id WHERE t1.customer_name="Dayana Robel"' } ] -""" \ No newline at end of file +"""
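
Usage note: PATCH 2/3 exposes the generator as
dbgpt_hub.data_generator.generate_dataset_with_gpt. Below is a minimal
sketch of how it might be driven; the API key value and the paths in the
explicit-arguments call are illustrative placeholders, and the sketch
assumes the relative imports from PATCH 3/3 are in place so that the
package imports resolve.

    import os

    from dbgpt_hub.data_generator import generate_dataset_with_gpt
    from dbgpt_hub.data_generator.utils import COT_PROMPT

    # GPTGenerator._chat_llm reads the key from the environment at call time.
    os.environ["OPENAI_API_KEY"] = "sk-..."  # illustrative placeholder, not a real key

    # With no arguments the defaults apply: gpt-3.5-turbo-16k, COT_PROMPT,
    # one text2sql pair per database, and the Spider paths under the repo root.
    generate_dataset_with_gpt()

    # When a dict is passed, all five keys must be present, because the
    # function indexes them directly instead of merging with the defaults.
    generate_dataset_with_gpt(
        {
            "model": "gpt-3.5-turbo-16k",
            "prompt": COT_PROMPT,
            "num_text2sql_pair_each_db": 10,
            "table_file_path": "dbgpt_hub/data/spider/tables.json",
            "output_path": "dbgpt_hub/data/spider/synthetic_data_with_gpt.json",
        }
    )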
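
Review note: generate_synthetic_dataset parses each model reply with
eval, which executes whatever Python the model happens to return. A
stricter parsing sketch is shown below, assuming replies follow the
list-of-dicts shape requested via FEW_SHOTS_EXAMPLE; parse_pairs is a
hypothetical helper for illustration and not part of these patches.

    import ast
    import json

    def parse_pairs(reply: str) -> list:
        """Parse a model reply into a list of text2sql pair dicts
        without executing it as code."""
        try:
            # Strict JSON first, for replies that honor the JSON instruction.
            data = json.loads(reply)
        except json.JSONDecodeError:
            # Fall back to literals-only parsing for the single-quoted
            # style shown in FEW_SHOTS_EXAMPLE; nothing is executed.
            data = ast.literal_eval(reply)
        if not isinstance(data, list):
            raise ValueError("expected a list of text2sql pairs")
        return data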