Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add schema-linking awel example #1081

Merged
merged 3 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions dbgpt/rag/operator/chart_draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any, Optional

from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.schemalinker.chart_draw import ChartDraw


class ChartDrawOperator(MapOperator[Any, Any]):
    """AWEL map operator that hands a SQL string to a ``ChartDraw`` helper.

    The helper is created once, at construction time, and reused for every
    value that flows through the operator.
    """

    def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs):
        """Create the operator.

        Args:
            connection (Optional[RDBMSDatabase]): Database connection passed
                to the underlying ``ChartDraw`` helper.
            **kwargs: Forwarded unchanged to ``MapOperator.__init__``.
        """
        super().__init__(**kwargs)
        drawer = ChartDraw(connection=connection)
        self._draw_chart = drawer

    def map(self, sql: str) -> str:
        """Delegate ``sql`` to ``ChartDraw.chart_draw``.

        Args:
            sql (str): SQL text to run.

        Returns:
            str: The value returned by ``ChartDraw.chart_draw`` for ``sql``.
        """
        rendered = self._draw_chart.chart_draw(sql=sql)
        return rendered
36 changes: 36 additions & 0 deletions dbgpt/rag/operator/schema_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking


class SchemaLinkingOperator(MapOperator[Any, Any]):
    """AWEL map operator that forwards a question to a ``SchemaLinking`` helper."""

    def __init__(
        self,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        **kwargs
    ):
        """Create the operator.

        Args:
            connection (Optional[RDBMSDatabase]): Database connection passed
                to the ``SchemaLinking`` helper.
            llm (Optional[LLMClient]): LLM client passed to the
                ``SchemaLinking`` helper.
            **kwargs: Forwarded unchanged to ``MapOperator.__init__``.
        """
        super().__init__(**kwargs)
        linker = SchemaLinking(connection=connection, llm=llm)
        self._schema_linking = linker

    async def map(self, query: str) -> str:
        """Delegate ``query`` to ``SchemaLinking.schema_linking``.

        Args:
            query (str): The question to find schema information for.

        Returns:
            str: The schema text returned by the helper.
        """
        schema_info = await self._schema_linking.schema_linking(query)
        return schema_info
27 changes: 27 additions & 0 deletions dbgpt/rag/operator/sql_exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any, Optional

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.task.base import IN
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.schemalinker.sql_exec import SqlExec


class SqlExecOperator(MapOperator[Any, Any]):
    """AWEL map operator that hands a SQL string to a ``SqlExec`` helper."""

    def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs):
        """Create the operator.

        Args:
            connection (Optional[RDBMSDatabase]): Database connection passed
                to the underlying ``SqlExec`` helper.
            **kwargs: Forwarded unchanged to ``MapOperator.__init__``.
        """
        super().__init__(**kwargs)
        executor = SqlExec(connection=connection)
        self._sql_exec = executor

    def map(self, sql: str) -> str:
        """Delegate ``sql`` to ``SqlExec.sql_exec``.

        Args:
            sql (str): SQL text to execute.

        Returns:
            str: The value returned by ``SqlExec.sql_exec`` for ``sql``.
        """
        outcome = self._sql_exec.sql_exec(sql=sql)
        return outcome
28 changes: 28 additions & 0 deletions dbgpt/rag/operator/sql_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.rag.schemalinker.sql_gen import SqlGen


class SqlGenOperator(MapOperator[Any, Any]):
    """AWEL map operator that forwards a prompt to a ``SqlGen`` helper."""

    def __init__(self, llm: Optional[LLMClient] = None, **kwargs):
        """Create the operator.

        ``llm`` defaults to ``None`` so the signature matches the other
        operators in this package (and ``SqlGen`` itself), which all treat
        their collaborators as optional keyword arguments.

        Args:
            llm (Optional[LLMClient]): LLM client passed to the ``SqlGen``
                helper.
            **kwargs: Forwarded unchanged to ``MapOperator.__init__``.
        """
        super().__init__(**kwargs)
        self._sql_gen = SqlGen(llm=llm)

    async def map(self, prompt_with_query_and_schema: str) -> str:
        """Delegate the prompt to ``SqlGen.sql_gen``.

        Args:
            prompt_with_query_and_schema (str): Prompt text sent to the LLM.

        Returns:
            str: The text returned by ``SqlGen.sql_gen``.
        """
        return await self._sql_gen.sql_gen(
            prompt_with_query_and_schema=prompt_with_query_and_schema
        )
Empty file.
34 changes: 34 additions & 0 deletions dbgpt/rag/schemalinker/chart_draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Optional

from dbgpt.datasource.rdbms.base import RDBMSDatabase


class ChartDraw:
    """Run a SQL query and plot its first two result columns as a bar chart."""

    def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs):
        """Store the connection used to run queries.

        Args:
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
                It may be omitted at construction time, but it is required
                by the time ``chart_draw`` is called.
            **kwargs: Forwarded to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self._connection = connection

    def chart_draw(self, sql: str) -> str:
        """Fetch the result of ``sql`` and draw it with matplotlib.

        The first result column is used for the x axis (categories) and the
        second for the bar heights.

        Args:
            sql (str): SQL text to run.

        Returns:
            str: ``str()`` of the object returned by ``run_to_df``.

        Raises:
            ValueError: If no connection was configured, or if the result
                has fewer than two columns and so cannot be charted.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self._connection is None:
            raise ValueError("ChartDraw requires a database connection.")

        # df: (Pandas) DataFrame
        df = self._connection.run_to_df(command=sql, fetch="all")

        # Validate before importing matplotlib so a bad query fails fast.
        if len(df.columns) < 2:
            raise ValueError(
                "ChartDraw needs at least two result columns "
                "(category, value) to draw a bar chart."
            )

        # Imported lazily, as in the original, so matplotlib is only needed
        # when a chart is actually drawn.
        import matplotlib.pyplot as plt

        category_column = df.columns[0]
        count_column = df.columns[1]
        plt.figure(figsize=(8, 4))
        plt.bar(df[category_column], df[count_column])
        plt.xlabel(category_column)
        plt.ylabel(count_column)
        plt.show()
        return str(df)
62 changes: 62 additions & 0 deletions dbgpt/rag/schemalinker/schema_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Optional

from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.util.chat_util import run_async_tasks

# One-shot prompt template for schema linking. The single ``{}`` placeholder
# at the end receives the candidate schema list (plus whatever the caller
# appends to it). The literal is split on sentence / section boundaries for
# readability only; the concatenated value is what matters.
INSTRUCTION = (
    "You need to filter out the most relevant database table schema information "
    "(it may be a single table or multiple tables) required to generate the SQL "
    "of the question query from the given database schema information. "
    "First, I will show you an example of an instruction followed by the correct "
    "schema response. "
    "Then, I will give you a new instruction, and you should write the schema "
    "response that appropriately completes the request."
    "\n### Example1 Instruction:\n"
    "['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']"
    "\n### Example1 Input:\n"
    "Find the age of student table"
    "\n### Example1 Response:\n"
    "['student(id, name, age, info)']"
    "\n###New Instruction:\n{}"
)
# Template for the trailing section of the prompt; ``{}`` receives the
# user's question.
INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:"


class SchemaLinking:
    """Ask an LLM which table schemas are relevant to a question."""

    def __init__(
        self,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        model_name: str = "gpt-3.5-turbo",
        **kwargs
    ):
        """Store the collaborators used by ``schema_linking``.

        Args:
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection
                whose table summaries are offered to the LLM.
            llm (Optional[LLMClient]): LLM client used to pick the schemas.
            model_name (str): Model name placed in the ``ModelRequest``.
                Defaults to the previously hard-coded ``"gpt-3.5-turbo"``.
            **kwargs: Forwarded to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self._connection = connection
        self._llm = llm
        self._model_name = model_name

    async def schema_linking(self, query: str) -> str:
        """Return the schema information the LLM selects for ``query``.

        Args:
            query (str): The question to find schema information for.

        Returns:
            str: ``str()`` of the ``text`` attribute of the LLM output.

        Raises:
            ValueError: If the connection or the LLM client was not
                configured.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self._connection is None:
            raise ValueError("SchemaLinking requires a database connection.")
        if self._llm is None:
            raise ValueError("SchemaLinking requires an LLM client.")

        # get all db schema info
        table_summaries = _parse_db_summary(self._connection)
        # Summaries are routed through Chunk exactly as before, in case
        # Chunk normalizes its content -- TODO confirm whether the round
        # trip is needed.
        chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
        chunks_content = [chunk.content for chunk in chunks]
        # The schema list and the question section both land in the single
        # placeholder at the end of INSTRUCTION.
        schema_prompt = INSTRUCTION.format(
            str(chunks_content) + INPUT_PROMPT.format(query)
        )
        messages = [
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        tasks = [self._llm.generate(request)]
        # get accurate schema info by llm
        schema = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        schema_text = str(schema[0].text)
        return schema_text
25 changes: 25 additions & 0 deletions dbgpt/rag/schemalinker/sql_exec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Optional

from dbgpt.datasource.rdbms.base import RDBMSDatabase


class SqlExec:
    """Execute SQL text against a database connection."""

    def __init__(self, connection: Optional[RDBMSDatabase] = None, **kwargs):
        """Store the connection used to run queries.

        Args:
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
                It may be omitted at construction time, but it is required
                by the time ``sql_exec`` is called.
            **kwargs: Forwarded to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self._connection = connection

    def sql_exec(self, sql: str) -> str:
        """Run ``sql`` on the connection and return the result as text.

        Args:
            sql (str): SQL text to execute.

        Returns:
            str: ``str()`` of whatever the connection's ``_query`` returns.

        Raises:
            ValueError: If no connection was configured.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self._connection is None:
            raise ValueError("SqlExec requires a database connection.")
        # NOTE(review): this calls the connection's private ``_query``
        # helper; kept as-is to preserve behavior.
        res = self._connection._query(query=sql, fetch="all")
        return str(res)
34 changes: 34 additions & 0 deletions dbgpt/rag/schemalinker/sql_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Optional

from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.util.chat_util import run_async_tasks


class SqlGen:
    """Generate SQL by sending a prepared prompt to an LLM."""

    def __init__(
        self,
        llm: Optional[LLMClient] = None,
        model_name: str = "gpt-3.5-turbo",
        **kwargs
    ):
        """Store the LLM client and model name used by ``sql_gen``.

        Args:
            llm (Optional[LLMClient]): LLM client. It may be omitted at
                construction time, but it is required by the time
                ``sql_gen`` is called.
            model_name (str): Model name placed in the ``ModelRequest``.
                Defaults to the previously hard-coded ``"gpt-3.5-turbo"``.
            **kwargs: Forwarded to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)
        self._llm = llm
        self._model_name = model_name

    async def sql_gen(self, prompt_with_query_and_schema: str) -> str:
        """Send the prompt to the LLM and return the generated text.

        Args:
            prompt_with_query_and_schema (str): Prompt text, sent as a
                single system message.

        Returns:
            str: The ``text`` attribute of the LLM output.

        Raises:
            ValueError: If no LLM client was configured.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self._llm is None:
            raise ValueError("SqlGen requires an LLM client.")

        messages = [
            ModelMessage(
                role=ModelMessageRoleType.SYSTEM, content=prompt_with_query_and_schema
            )
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        tasks = [self._llm.generate(request)]
        output = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        sql = output[0].text
        return sql
Loading