diff --git a/dbgpt_hub/data_generator/__init__.py b/dbgpt_hub/data_generator/__init__.py
new file mode 100644
index 0000000..dee6f5c
--- /dev/null
+++ b/dbgpt_hub/data_generator/__init__.py
@@ -0,0 +1,8 @@
+"""
+dbgpt_hub.data_generator
+========================
+"""
+
+from .gpt_generator_api import generate_dataset_with_gpt
+
+__all__ = ["generate_dataset_with_gpt"]
diff --git a/dbgpt_hub/data_generator/gpt_generator.py b/dbgpt_hub/data_generator/gpt_generator.py
new file mode 100644
index 0000000..3023791
--- /dev/null
+++ b/dbgpt_hub/data_generator/gpt_generator.py
@@ -0,0 +1,143 @@
+import ast
+import json
+import os
+
+from openai import OpenAI
+from tqdm import tqdm
+
+from .llm_generator import LLMGenerator
+from .utils import COT_PROMPT, FEW_SHOTS_EXAMPLE
+
+
+class GPTGenerator(LLMGenerator):
+    def __init__(
+        self,
+        model: str = "gpt-3.5-turbo-16k",
+        model_temperature: float = 0.0,
+        max_tokens: int = 2048,
+        prompt: str = "",
+        num_text2sql_pair_each_db: int = 10,
+        table_file_path: str = "",
+        output_path: str = "",
+    ):
+        # Fall back to the bundled Spider paths and the default CoT prompt
+        # when no overrides are supplied.
+        self.table_file_path = table_file_path or "../data/spider/tables.json"
+        self.output_path = output_path or "../data/spider/synthetic_data_with_gpt.json"
+        self.prompt = prompt or COT_PROMPT
+        self.model = model
+        self.model_temperature = model_temperature
+        self.max_tokens = max_tokens
+        self.num_text2sql_pair_each_db = num_text2sql_pair_each_db
+
+        self.synthetic_dataset = []
+
+    def generate_synthetic_dataset(self):
+        """Generate a synthetic dataset.
+
+        By default, a Spider-like synthetic dataset is generated.
+        """
+        synthetic_dataset = []
+
+        with open(self.table_file_path) as f:
+            tables = json.load(f)
+
+        # Split the per-database quota roughly evenly across the three
+        # difficulty levels; any remainder goes to the hard bucket.
+        easy_count = self.num_text2sql_pair_each_db // 3
+        medium_count = self.num_text2sql_pair_each_db // 3
+        hard_count = self.num_text2sql_pair_each_db - easy_count - medium_count
+
+        db_dict = {}
+        for item in tqdm(tables):
+            table_names = item["table_names_original"]
+            # Drop the leading "*" pseudo-column; Spider column ids are
+            # therefore offset by one relative to this list.
+            columns = item["column_names_original"][1:]
+            primary_keys = item["primary_keys"]
+            foreign_keys = item["foreign_keys"]
+            schema = (
+                item["db_id"]
+                + " database contains tables such as "
+                + ", ".join(table_names)
+                + ". "
+            )
+            for i, name in enumerate(table_names):
+                data = [column[1] for column in columns if column[0] == i]
+                schema += (
+                    "Table " + name + " has columns such as " + ", ".join(data) + ". "
+                )
+
+                # Get primary key info.
+                for j in range(len(primary_keys)):
+                    if columns[primary_keys[j] - 1][0] == i:
+                        schema += (
+                            columns[primary_keys[j] - 1][1] + " is the primary key.\n"
+                        )
+
+            # Get foreign key info.
+            for key in foreign_keys:
+                schema += (
+                    "The "
+                    + columns[key[0] - 1][1]
+                    + " of "
+                    + table_names[columns[key[0] - 1][0]]
+                    + " is the foreign key of "
+                    + columns[key[1] - 1][1]
+                    + " of "
+                    + table_names[columns[key[1] - 1][0]]
+                    + ".\n"
+                )
+
+            db_dict[item["db_id"]] = schema
+
+            try:
+                # One request per database: the prompt asks the model for all
+                # easy/medium/hard pairs for this schema in a single reply.
+                text2sql_pair = self._chat_llm(
+                    self.prompt.format(
+                        easy_count=easy_count,
+                        medium_count=medium_count,
+                        hard_count=hard_count,
+                        schema=schema,
+                        few_shots_example=FEW_SHOTS_EXAMPLE,
+                    )
+                )
+                # The few-shot format uses single quotes, so parse the reply
+                # as a Python literal instead of JSON; literal_eval cannot
+                # execute arbitrary model output, unlike eval.
+                text2sql_pair = ast.literal_eval(text2sql_pair)
+                synthetic_dataset += text2sql_pair
+            except Exception:
+                # Skip databases whose reply cannot be parsed.
+                continue
+
+        self.synthetic_dataset = synthetic_dataset
+        self._writeout_dataset()
+
+    def _chat_llm(self, prompt):
+        client = OpenAI(
+            api_key=os.environ["OPENAI_API_KEY"],
+        )
+
+        completion = client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": prompt,
+                }
+            ],
+            model=self.model,
+            temperature=self.model_temperature,
+            max_tokens=self.max_tokens,
+            top_p=1,
+            frequency_penalty=0,
+            presence_penalty=0,
+        )
+        return completion.choices[0].message.content
+
+    def _writeout_dataset(self):
+        with open(self.output_path, "w", encoding="utf-8") as f:
+            json.dump(self.synthetic_dataset, f, indent=4, ensure_ascii=False)
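A note on the parsing step above: the few-shot exemplars in `utils.py` below use Python-style single quotes, so the model's reply is typically not valid JSON and `json.loads` would reject it. `ast.literal_eval` parses such Python literals without executing arbitrary model output, which is why `generate_synthetic_dataset` uses it rather than `json.loads` or a bare `eval`. A minimal sketch with a fabricated reply string:

```python
import ast

# A reply shaped like FEW_SHOTS_EXAMPLE: single-quoted, hence not valid JSON.
reply = "[{'db_id': 'concert_singer', 'question': 'How many singers are there?', 'query': 'SELECT count(*) FROM singer'}]"

pairs = ast.literal_eval(reply)  # json.loads(reply) would raise JSONDecodeError here
assert pairs[0]["query"] == "SELECT count(*) FROM singer"
```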
diff --git a/dbgpt_hub/data_generator/gpt_generator_api.py b/dbgpt_hub/data_generator/gpt_generator_api.py
new file mode 100644
index 0000000..d34200f
--- /dev/null
+++ b/dbgpt_hub/data_generator/gpt_generator_api.py
@@ -0,0 +1,41 @@
+import os
+import sys
+
+from typing import Any, Dict, Optional
+
+from .gpt_generator import GPTGenerator
+from .utils import COT_PROMPT
+
+ROOT_PATH = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+)
+sys.path.append(ROOT_PATH)
+
+
+def generate_dataset_with_gpt(args: Optional[Dict[str, Any]] = None):
+    # Default arguments; any keys supplied by the caller override them.
+    defaults = {
+        "model": "gpt-3.5-turbo-16k",
+        "prompt": COT_PROMPT,
+        "num_text2sql_pair_each_db": 1,
+        "table_file_path": os.path.join(
+            ROOT_PATH, "dbgpt_hub/data/spider/tables.json"
+        ),
+        "output_path": os.path.join(
+            ROOT_PATH, "dbgpt_hub/data/spider/synthetic_data_with_gpt.json"
+        ),
+    }
+    args = {**defaults, **(args or {})}
+
+    # Run the GPT generator.
+    gpt_generator = GPTGenerator(
+        model=args["model"],
+        prompt=args["prompt"],
+        num_text2sql_pair_each_db=args["num_text2sql_pair_each_db"],
+        table_file_path=args["table_file_path"],
+        output_path=args["output_path"],
+    )
+    gpt_generator.generate_synthetic_dataset()
+
+
+if __name__ == "__main__":
+    generate_dataset_with_gpt()
diff --git a/dbgpt_hub/data_generator/llm_generator.py b/dbgpt_hub/data_generator/llm_generator.py
new file mode 100644
index 0000000..783acd3
--- /dev/null
+++ b/dbgpt_hub/data_generator/llm_generator.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+
+
+class LLMGenerator(ABC):
+    """An interface for large language model data generators.
+
+    An LLM data generator accepts prompts and generates a synthetic
+    Text2SQL dataset.
+    """
+
+    @abstractmethod
+    def generate_synthetic_dataset(self):
+        """Generate a synthetic dataset."""
+
+    @abstractmethod
+    def _chat_llm(self, prompt):
+        """Send a prompt to the LLM and return its reply."""
+
+    @abstractmethod
+    def _writeout_dataset(self):
+        """Write the generated dataset to disk."""
diff --git a/dbgpt_hub/data_generator/utils.py b/dbgpt_hub/data_generator/utils.py
new file mode 100644
index 0000000..bfb6697
--- /dev/null
+++ b/dbgpt_hub/data_generator/utils.py
@@ -0,0 +1,35 @@
+COT_PROMPT = """
+Please use the following database information to generate natural language questions of different difficulty levels, together with their corresponding SQL queries.
+When generating natural language questions and their corresponding SQL queries, you should consider different SQL operators such as WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, JOIN, INTERSECT, EXCEPT, UNION, NOT IN, OR, AND, EXISTS, LIKE, as well as nested queries.
+Moreover, please make sure that each table in the database appears in at least one query.
+
+There are three different difficulty levels, which are defined as follows:
+Easy: Queries that require basic filtering or aggregation on a single table.
+Medium: Queries that encompass more complex filtering or aggregation and involve joining multiple tables.
+Hard: Queries that entail advanced filtering or aggregation, multiple joins, and the use of subqueries.
+
+Here is the basic information of the database: {schema}
+
+Based on the tables, columns, primary keys, foreign keys and the difficulty levels defined above, generate {easy_count} Easy, {medium_count} Medium, and {hard_count} Hard natural language questions with their corresponding SQL queries.
+Provide your answer in JSON form. Reply with only the answer in JSON form and include no other commentary:
+RESPONSE FORMAT:
+{few_shots_example}
+
+The "db_id" in the above examples means the name of the database used. Do not fill it in with a "_database" suffix.
+""" + +FEW_SHOTS_EXAMPLE = """ +[ + { + 'db_id': 'music_2', + 'question': 'Who performed the song named "Le Pop"?', + 'query': 'SELECT T2.firstname, T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate=T2.id JOIN Songs AS T3 ON T3.SongId=T1.SongId WHERE T3.Title="Le Pop"' + }, + { + 'db_id': 'insurance_fnol', + 'question': 'Tell me the types of the policy used by the customer named "Dayana Robel".', + 'query': 'SELECT DISTINCT t3.policy_type_code FROM customers AS t1 JOIN customers_policies AS t2 ON t1.customer_id=t2.customer_id JOIN available_policies AS t3 ON t2.policy_id=t3.policy_id WHERE t1.customer_name="Dayana Robel"' + } +] +""" diff --git a/poetry.lock b/poetry.lock index a9607d0..90a98ef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -616,6 +616,17 @@ files = [ [package.extras] graph = ["objgraph (>=1.7.2)"] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -1946,6 +1957,29 @@ files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "openai" +version = "1.6.1" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.6.1-py3-none-any.whl", hash = "sha256:bc9f774838d67ac29fb24cdeb2d58faf57de8b311085dcd1348f7aa02a96c7ee"}, + {file = "openai-1.6.1.tar.gz", hash = "sha256:d553ca9dbf9486b08e75b09e8671e4f638462aaadccfced632bf490fc3d75fa2"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "orjson" version = "3.9.10" diff --git a/pyproject.toml b/pyproject.toml index 44fd75f..e07d0f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dbgpt_hub" -version = "0.3.0" +version = "0.3.1" description = "DB-GPT-Hub: Text-to-SQL parsing with LLMs" authors = ["Your Name "] license = "MIT" @@ -60,6 +60,7 @@ nvidia-nvjitlink-cu12 = "^12.3.52" prettytable = "^3.9.0" docopt = "^0.6.2" bitsandbytes = "0.41.3.post2" +openai = "^1.6.1" [build-system] requires = ["poetry-core"]
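With these changes in place, the generator can be driven through the exported `generate_dataset_with_gpt` entry point. Below is a minimal usage sketch, assuming the package is installed, `OPENAI_API_KEY` is exported, and Spider's `tables.json` sits at the default location under `dbgpt_hub/data/spider/`; the output path shown is purely illustrative:

```python
import os

from dbgpt_hub.data_generator import generate_dataset_with_gpt

# _chat_llm reads the key from the environment when it builds the client.
assert "OPENAI_API_KEY" in os.environ, "export OPENAI_API_KEY first"

# Keys omitted from the dict fall back to the defaults in gpt_generator_api.py.
generate_dataset_with_gpt(
    {
        "num_text2sql_pair_each_db": 3,  # split as 1 easy + 1 medium + 1 hard
        "output_path": "dbgpt_hub/data/spider/my_synthetic_data.json",  # illustrative
    }
)
```

Each run walks every database in `tables.json`, requests one batch of question/SQL pairs per schema, and writes the combined result to `output_path` as a single JSON list.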