Skip to content

Commit 298bdcf

Browse files
junewglHopshine
authored andcommitted
feat: add schema-linking awel example (eosphoros-ai#1081)
1 parent d8d5fd1 commit 298bdcf

File tree

5 files changed

+445
-0
lines changed

5 files changed

+445
-0
lines changed

dbgpt/rag/operator/schema_linking.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from typing import Any, Optional
2+
3+
from dbgpt.core import LLMClient
4+
from dbgpt.core.awel import MapOperator
5+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
6+
from dbgpt.rag.schemalinker.schema_linking import SchemaLinking
7+
from dbgpt.storage.vector_store.connector import VectorStoreConnector
8+
9+
10+
class SchemaLinkingOperator(MapOperator[Any, Any]):
11+
"""The Schema Linking Operator."""
12+
13+
def __init__(
14+
self,
15+
top_k: int = 5,
16+
connection: Optional[RDBMSDatabase] = None,
17+
llm: Optional[LLMClient] = None,
18+
model_name: Optional[str] = None,
19+
vector_store_connector: Optional[VectorStoreConnector] = None,
20+
**kwargs
21+
):
22+
"""Init the schema linking operator
23+
Args:
24+
connection (RDBMSDatabase): The connection.
25+
llm (Optional[LLMClient]): base llm
26+
"""
27+
super().__init__(**kwargs)
28+
29+
self._schema_linking = SchemaLinking(
30+
top_k=top_k,
31+
connection=connection,
32+
llm=llm,
33+
model_name=model_name,
34+
vector_store_connector=vector_store_connector,
35+
)
36+
37+
async def map(self, query: str) -> str:
38+
"""retrieve table schemas.
39+
Args:
40+
query (str): query.
41+
Return:
42+
str: schema info
43+
"""
44+
return str(await self._schema_linking.schema_linking_with_llm(query))

dbgpt/rag/schemalinker/__init__.py

Whitespace-only changes.

dbgpt/rag/schemalinker/base_linker.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List
3+
4+
5+
class BaseSchemaLinker(ABC):
6+
"""Base Linker."""
7+
8+
def schema_linking(self, query: str) -> List:
9+
"""
10+
Args:
11+
query (str): query text
12+
Returns:
13+
List: list of schema
14+
"""
15+
return self._schema_linking(query)
16+
17+
def schema_linking_with_vector_db(self, query: str) -> List:
18+
"""
19+
Args:
20+
query (str): query text
21+
Returns:
22+
List: list of schema
23+
"""
24+
return self._schema_linking_with_vector_db(query)
25+
26+
async def schema_linking_with_llm(self, query: str) -> List:
27+
""" "
28+
Args:
29+
query(str): query text
30+
Returns:
31+
List: list of schema
32+
"""
33+
return await self._schema_linking_with_llm(query)
34+
35+
@abstractmethod
36+
def _schema_linking(self, query: str) -> List:
37+
"""
38+
Args:
39+
query (str): query text
40+
Returns:
41+
List: list of schema
42+
"""
43+
44+
@abstractmethod
45+
def _schema_linking_with_vector_db(self, query: str) -> List:
46+
"""
47+
Args:
48+
query (str): query text
49+
Returns:
50+
List: list of schema
51+
"""
52+
53+
@abstractmethod
54+
async def _schema_linking_with_llm(self, query: str) -> List:
55+
"""
56+
Args:
57+
query (str): query text
58+
Returns:
59+
List: list of schema
60+
"""
+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from functools import reduce
2+
from typing import List, Optional
3+
4+
from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest
5+
from dbgpt.datasource.rdbms.base import RDBMSDatabase
6+
from dbgpt.rag.chunk import Chunk
7+
from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker
8+
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
9+
from dbgpt.storage.vector_store.connector import VectorStoreConnector
10+
from dbgpt.util.chat_util import run_async_tasks
11+
12+
INSTRUCTION = (
13+
"You need to filter out the most relevant database table schema information (it may be a single "
14+
"table or multiple tables) required to generate the SQL of the question query from the given "
15+
"database schema information. First, I will show you an example of an instruction followed by "
16+
"the correct schema response. Then, I will give you a new instruction, and you should write "
17+
"the schema response that appropriately completes the request.\n### Example1 Instruction:\n"
18+
"['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']\n### Example1 "
19+
"Input:\nFind the age of student table\n### Example1 Response:\n['student(id, name, age, info)']"
20+
"\n###New Instruction:\n{}"
21+
)
22+
INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:"
23+
24+
25+
class SchemaLinking(BaseSchemaLinker):
26+
"""SchemaLinking by LLM"""
27+
28+
def __init__(
29+
self,
30+
top_k: int = 5,
31+
connection: Optional[RDBMSDatabase] = None,
32+
llm: Optional[LLMClient] = None,
33+
model_name: Optional[str] = None,
34+
vector_store_connector: Optional[VectorStoreConnector] = None,
35+
**kwargs
36+
):
37+
"""
38+
Args:
39+
connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
40+
llm (Optional[LLMClient]): base llm
41+
"""
42+
super().__init__(**kwargs)
43+
self._top_k = top_k
44+
self._connection = connection
45+
self._llm = llm
46+
self._model_name = model_name
47+
self._vector_store_connector = vector_store_connector
48+
49+
def _schema_linking(self, query: str) -> List:
50+
"""get all db schema info"""
51+
table_summaries = _parse_db_summary(self._connection)
52+
chunks = [Chunk(content=table_summary) for table_summary in table_summaries]
53+
chunks_content = [chunk.content for chunk in chunks]
54+
return chunks_content
55+
56+
def _schema_linking_with_vector_db(self, query: str) -> List:
57+
queries = [query]
58+
candidates = [
59+
self._vector_store_connector.similar_search(query, self._top_k)
60+
for query in queries
61+
]
62+
candidates = reduce(lambda x, y: x + y, candidates)
63+
return candidates
64+
65+
async def _schema_linking_with_llm(self, query: str) -> List:
66+
chunks_content = self.schema_linking(query)
67+
schema_prompt = INSTRUCTION.format(
68+
str(chunks_content) + INPUT_PROMPT.format(query)
69+
)
70+
messages = [
71+
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
72+
]
73+
request = ModelRequest(model=self._model_name, messages=messages)
74+
tasks = [self._llm.generate(request)]
75+
# get accurate schem info by llm
76+
schema = await run_async_tasks(tasks=tasks, concurrency_limit=1)
77+
schema_text = schema[0].text
78+
return schema_text

0 commit comments

Comments
 (0)