Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpt synthetic dataset generator and update poetry #198

Merged
merged 4 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dbgpt_hub/data_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
dbgpt_hub.data_generator
==============
"""

from .gpt_generator_api import generate_dataset_with_gpt

__all__ = ["generate_dataset_with_gpt"]
143 changes: 143 additions & 0 deletions dbgpt_hub/data_generator/gpt_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import ast
import json
import os

from openai import OpenAI
from tqdm import tqdm

from .llm_generator import LLMGenerator
from .utils import COT_PROMPT, FEW_SHOTS_EXAMPLE


class GPTGenerator(LLMGenerator):
    """Synthetic Text2SQL dataset generator backed by an OpenAI chat model.

    Renders each database schema from a Spider-style ``tables.json`` into a
    natural-language description, prompts the model for question/SQL pairs at
    three difficulty levels, and writes the merged result to ``output_path``.
    """

    def __init__(
        self,
        model: str = "gpt-3.5-turbo-16k",
        model_temperature: float = 0,  # float: the API accepts fractional temperatures
        max_tokens: int = 2048,
        prompt: str = "",
        num_text2sql_pair_each_db: int = 10,
        table_file_path: str = "",
        output_path: str = "",
    ):
        # Empty-string arguments fall back to the repo's default Spider paths.
        self.table_file_path = table_file_path or "../data/spider/tables.json"
        self.output_path = output_path or "../data/spider/synthetic_data_with_gpt.json"
        self.prompt = prompt or COT_PROMPT

        self.model = model
        self.model_temperature = model_temperature
        self.max_tokens = max_tokens
        self.num_text2sql_pair_each_db = num_text2sql_pair_each_db

        # Filled by generate_synthetic_dataset(); written by _writeout_dataset().
        self.synthetic_dataset = []

    def generate_synthetic_dataset(self):
        """Generate a Spider-like synthetic dataset and write it to disk.

        For every database in the table file a schema description is built and
        the LLM is asked for question/SQL pairs; a database whose generation or
        parsing fails is skipped (best-effort).
        """
        synthetic_dataset = []

        # Use a context manager so the handle is closed (json.load(open(...))
        # in the original leaked it).
        with open(self.table_file_path) as f:
            tables = json.load(f)
        db_num = len(tables)
        # NOTE(review): dividing by the database count means easy/medium
        # collapse to 0 as soon as db_num exceeds num_text2sql_pair_each_db,
        # so every pair is requested as "Hard". Confirm whether a split by 3
        # (per the parameter's "each_db" name) was intended.
        easy_count = int(self.num_text2sql_pair_each_db / db_num)
        medium_count = int(self.num_text2sql_pair_each_db / db_num)
        hard_count = self.num_text2sql_pair_each_db - easy_count - medium_count

        db_dict = {}
        for item in tqdm(tables):
            # `table_names` (not `tables`) so the loaded JSON list above is
            # not shadowed inside the loop.
            table_names = item["table_names_original"]
            columns = item["column_names_original"][1:]  # drop the leading "*" entry
            primary_keys = item["primary_keys"]
            foreign_keys = item["foreign_keys"]
            schema = (
                item["db_id"]
                + " database contains tables such as "
                + ", ".join(table_names)
                + ". "
            )
            for i, name in enumerate(table_names):
                data = [column[1] for column in columns if column[0] == i]
                schema += (
                    "Table " + name + " has columns such as " + ", ".join(data) + ". "
                )

                # Primary keys: indices are 1-based positions in the original
                # column list, hence the `- 1` after slicing off "*".
                for pk in primary_keys:
                    if columns[pk - 1][0] == i:
                        schema += (
                            columns[pk - 1][1]
                            + " is the primary key."
                            + "\n"
                        )

            # Foreign keys: (source_column, target_column) index pairs.
            for key in foreign_keys:
                schema += (
                    "The "
                    + columns[key[0] - 1][1]
                    + " of "
                    + table_names[columns[key[0] - 1][0]]
                    + " is the foreign key of "
                    + columns[key[1] - 1][1]
                    + " of "
                    + table_names[columns[key[1] - 1][0]]
                    + ".\n"
                )

            db_dict[item["db_id"]] = schema

            try:
                # Single generated data for one DB
                for _ in range(self.num_text2sql_pair_each_db):
                    text2sql_pair = self._chat_llm(
                        self.prompt.format(
                            easy_count=easy_count,
                            medium_count=medium_count,
                            hard_count=hard_count,
                            schema=schema,
                            few_shots_example=FEW_SHOTS_EXAMPLE,
                        )
                    )
                    # The model replies with a Python-literal list of dicts;
                    # parse it safely instead of eval() on untrusted output.
                    text2sql_pair = ast.literal_eval(text2sql_pair)
                    synthetic_dataset += text2sql_pair
            except Exception:
                # Best-effort: skip this database and move on to the next.
                continue

        self.synthetic_dataset = synthetic_dataset
        self._writeout_dataset()

    def _chat_llm(self, prompt):
        """Send a single-turn user prompt to the chat model; return reply text.

        Reads the API key from the OPENAI_API_KEY environment variable and
        raises KeyError if it is unset.
        """
        client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"],
        )

        completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=self.model,
            temperature=self.model_temperature,
            max_tokens=self.max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        return completion.choices[0].message.content

    def _writeout_dataset(self):
        """Dump the collected dataset to ``self.output_path`` as pretty JSON."""
        with open(self.output_path, "w", encoding="utf-8") as s:
            json.dump(self.synthetic_dataset, s, indent=4, ensure_ascii=False)
41 changes: 41 additions & 0 deletions dbgpt_hub/data_generator/gpt_generator_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import sys

from typing import Optional, Dict, Any
from .gpt_generator import GPTGenerator
from .utils import COT_PROMPT

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)


def generate_dataset_with_gpt(args: Optional[Dict[str, Any]] = None):
    """Generate a synthetic Text2SQL dataset with the GPT generator.

    Args:
        args: Optional overrides for the generator configuration. Keys omitted
            from the dict fall back to the defaults below, so a partial dict
            (e.g. only ``{"model": ...}``) is accepted — previously any
            missing key raised ``KeyError``.
    """
    # Defaults for every supported argument.
    defaults: Dict[str, Any] = {
        "model": "gpt-3.5-turbo-16k",
        "prompt": COT_PROMPT,
        "num_text2sql_pair_each_db": 1,
        "table_file_path": os.path.join(
            ROOT_PATH, "dbgpt_hub/data/spider/tables.json"
        ),
        "output_path": os.path.join(
            ROOT_PATH, "dbgpt_hub/data/spider/synthetic_data_with_gpt.json"
        ),
    }
    # Overlay caller-supplied values on the defaults (replaces the no-op
    # `args = args` else-branch of the original).
    args = {**defaults, **(args or {})}

    # Run GPT Generator
    gpt_generator = GPTGenerator(
        model=args["model"],
        prompt=args["prompt"],
        num_text2sql_pair_each_db=args["num_text2sql_pair_each_db"],
        table_file_path=args["table_file_path"],
        output_path=args["output_path"],
    )
    gpt_generator.generate_synthetic_dataset()


if __name__ == "__main__":
    generate_dataset_with_gpt()
24 changes: 24 additions & 0 deletions dbgpt_hub/data_generator/llm_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod

from typing import Any, Dict, List, Optional, Tuple, Union


class LLMGenerator(ABC):
"""An interface for large language model data generator.
A LLM data generator can accept prompts and generate synthetic Text2SQL dataset.
"""

@abstractmethod
def generate_synthetic_dataset(self):
"""Function for generating synthetic dataset"""
pass

@abstractmethod
def _chat_llm(self):
"""Function for interacting with LLMs"""
pass

@abstractmethod
def _writeout_dataset(self):
"""Function for writing out generated dataset"""
pass
35 changes: 35 additions & 0 deletions dbgpt_hub/data_generator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
COT_PROMPT = """
Please use the following database information, genarate different difficulty level of natural language questions with their corresponding SQL querires.
There are three different difficulty levels:
When generating natual language questions and their corresponding SQL queries, you should consider different SQL operators such as WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, JOIN, INTERSECT, EXCEPT, UNION, NOT IN, OR, AND, EXISTS, LIKE as well as nested queries.
Moreover, please make sure that each table in the database appears in at least one query.

There are three different difficulty levels, which are defined as follows:
Easy: Queries that require basic filtering or aggregation on a single table.
Medium: Queries that encompass more complex filtering or aggregation and involve joining multiple tables.
Hard: Queries that entail advanced filtering or aggregation, multiple joins, and the use of subqueries.

Here is the basic information of database: {schema}

Based on the tables, columns, primary keys, foreign keys and different difficulty levels, generate {easy_count} Easy, {medium_count} Medium, and {hard_count} Hard natural language questions with their correlated SQL queries.
Provide your answer in JSON form. Reply with only the answer in JSON form and include no other commentary:
RESPONSE FORMAT:
{few_shots_example}

The "db_id" in the above examples means the name of used database. Do not fill out it with "_database" suffix.
"""

FEW_SHOTS_EXAMPLE = """
[
{
'db_id': 'music_2',
'question': 'Who performed the song named "Le Pop"?',
'query': 'SELECT T2.firstname, T2.lastname FROM Performance AS T1 JOIN Band AS T2 ON T1.bandmate=T2.id JOIN Songs AS T3 ON T3.SongId=T1.SongId WHERE T3.Title="Le Pop"'
},
{
'db_id': 'insurance_fnol',
'question': 'Tell me the types of the policy used by the customer named "Dayana Robel".',
'query': 'SELECT DISTINCT t3.policy_type_code FROM customers AS t1 JOIN customers_policies AS t2 ON t1.customer_id=t2.customer_id JOIN available_policies AS t3 ON t2.policy_id=t3.policy_id WHERE t1.customer_name="Dayana Robel"'
}
]
"""
34 changes: 34 additions & 0 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dbgpt_hub"
version = "0.3.0"
version = "0.3.1"
description = "DB-GPT-Hub: Text-to-SQL parsing with LLMs"
authors = ["Your Name <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -60,6 +60,7 @@ nvidia-nvjitlink-cu12 = "^12.3.52"
prettytable = "^3.9.0"
docopt = "^0.6.2"
bitsandbytes = "0.41.3.post2"
openai = "^1.6.1"

[build-system]
requires = ["poetry-core"]
Expand Down
Loading