Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add schema-linking awel example #1081

Merged
merged 3 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions dbgpt/rag/operator/schema_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Any, Optional

from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
from dbgpt.storage.vector_store.connector import VectorStoreConnector


class SchemaLinkingOperator(MapOperator[Any, Any]):
"""The Schema Linking Operator."""

def __init__(
self,
top_k: int = 5,
connection: Optional[RDBMSDatabase] = None,
llm: Optional[LLMClient] = None,
model_name: Optional[str] = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
**kwargs
):
"""Init the schema linking operator
Args:
connection (RDBMSDatabase): The connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)

self._schema_linking = SchemaLinking(
top_k=top_k,
connection=connection,
llm=llm,
model_name=model_name,
vector_store_connector=vector_store_connector,
)

async def map(self, query: str) -> str:
"""retrieve table schemas.
Args:
query (str): query.
Return:
str: schema info
"""
return str(await self._schema_linking.schema_linking_with_llm(query))
Empty file.
60 changes: 60 additions & 0 deletions dbgpt/rag/schemalinker/base_linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import List


class BaseSchemaLinker(ABC):
"""Base Linker."""

def schema_linking(self, query: str) -> List:
"""
Args:
query (str): query text
Returns:
List: list of schema
"""
return self._schema_linking(query)

def schema_linking_with_vector_db(self, query: str) -> List:
"""
Args:
query (str): query text
Returns:
List: list of schema
"""
return self._schema_linking_with_vector_db(query)

async def schema_linking_with_llm(self, query: str) -> List:
""" "
Args:
query(str): query text
Returns:
List: list of schema
"""
return await self._schema_linking_with_llm(query)

@abstractmethod
def _schema_linking(self, query: str) -> List:
"""
Args:
query (str): query text
Returns:
List: list of schema
"""

@abstractmethod
def _schema_linking_with_vector_db(self, query: str) -> List:
"""
Args:
query (str): query text
Returns:
List: list of schema
"""

@abstractmethod
async def _schema_linking_with_llm(self, query: str) -> List:
"""
Args:
query (str): query text
Returns:
List: list of schema
"""
78 changes: 78 additions & 0 deletions dbgpt/rag/schemalinker/schema_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from functools import reduce
from typing import List, Optional

from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks

INSTRUCTION = (
"You need to filter out the most relevant database table schema information (it may be a single "
"table or multiple tables) required to generate the SQL of the question query from the given "
"database schema information. First, I will show you an example of an instruction followed by "
"the correct schema response. Then, I will give you a new instruction, and you should write "
"the schema response that appropriately completes the request.\n### Example1 Instruction:\n"
"['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']\n### Example1 "
"Input:\nFind the age of student table\n### Example1 Response:\n['student(id, name, age, info)']"
"\n###New Instruction:\n{}"
)
INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:"


class SchemaLinking(BaseSchemaLinker):
"""SchemaLinking by LLM"""

def __init__(
self,
top_k: int = 5,
connection: Optional[RDBMSDatabase] = None,
llm: Optional[LLMClient] = None,
model_name: Optional[str] = None,
vector_store_connector: Optional[VectorStoreConnector] = None,
**kwargs
):
"""
Args:
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
llm (Optional[LLMClient]): base llm
"""
super().__init__(**kwargs)
self._top_k = top_k
self._connection = connection
self._llm = llm
self._model_name = model_name
self._vector_store_connector = vector_store_connector

def _schema_linking(self, query: str) -> List:
"""get all db schema info"""
table_summaries = _parse_db_summary(self._connection)
chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
chunks_content = [chunk.content for chunk in chunks]
return chunks_content

def _schema_linking_with_vector_db(self, query: str) -> List:
queries = [query]
candidates = [
self._vector_store_connector.similar_search(query, self._top_k)
for query in queries
]
candidates = reduce(lambda x, y: x + y, candidates)
return candidates

async def _schema_linking_with_llm(self, query: str) -> List:
chunks_content = self.schema_linking(query)
schema_prompt = INSTRUCTION.format(
str(chunks_content) + INPUT_PROMPT.format(query)
)
messages = [
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
]
request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm.generate(request)]
# get accurate schem info by llm
schema = await run_async_tasks(tasks=tasks, concurrency_limit=1)
schema_text = schema[0].text
return schema_text
Loading