|
| 1 | +from functools import reduce |
| 2 | +from typing import List, Optional |
| 3 | + |
| 4 | +from dbgpt.core import LLMClient, ModelMessage, ModelMessageRoleType, ModelRequest |
| 5 | +from dbgpt.datasource.rdbms.base import RDBMSDatabase |
| 6 | +from dbgpt.rag.chunk import Chunk |
| 7 | +from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker |
| 8 | +from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary |
| 9 | +from dbgpt.storage.vector_store.connector import VectorStoreConnector |
| 10 | +from dbgpt.util.chat_util import run_async_tasks |
| 11 | + |
# Few-shot system prompt for LLM-based schema linking: shows the model one
# worked example (schema list -> relevant table), then the trailing {} is
# filled with the real schema list for the new request.
INSTRUCTION = (
    "You need to filter out the most relevant database table schema information (it may be a single "
    "table or multiple tables) required to generate the SQL of the question query from the given "
    "database schema information. First, I will show you an example of an instruction followed by "
    "the correct schema response. Then, I will give you a new instruction, and you should write "
    "the schema response that appropriately completes the request.\n### Example1 Instruction:\n"
    "['job(id, name, age)', 'user(id, name, age)', 'student(id, name, age, info)']\n### Example1 "
    "Input:\nFind the age of student table\n### Example1 Response:\n['student(id, name, age, info)']"
    "\n###New Instruction:\n{}"
)
# Appended after the schema list; {} carries the user question and the
# "###New Response:" tail cues the model to answer in the example's format.
INPUT_PROMPT = "\n###New Input:\n{}\n###New Response:"
| 23 | + |
| 24 | + |
class SchemaLinking(BaseSchemaLinker):
    """Link a natural-language question to relevant database schema info.

    Three strategies are provided:
    - ``_schema_linking``: return every table summary from the database,
    - ``_schema_linking_with_vector_db``: similarity search in a vector store,
    - ``_schema_linking_with_llm``: ask an LLM to filter the full schema.
    """

    def __init__(
        self,
        top_k: int = 5,
        connection: Optional[RDBMSDatabase] = None,
        llm: Optional[LLMClient] = None,
        model_name: Optional[str] = None,
        vector_store_connector: Optional[VectorStoreConnector] = None,
        **kwargs
    ):
        """Initialize the schema linker.

        Args:
            top_k (int): number of most-similar table schemas returned by
                the vector-store search.
            connection (Optional[RDBMSDatabase]): RDBMSDatabase connection.
            llm (Optional[LLMClient]): base llm client used for LLM linking.
            model_name (Optional[str]): name of the model passed to the llm.
            vector_store_connector (Optional[VectorStoreConnector]): vector
                store holding table-schema embeddings.
        """
        super().__init__(**kwargs)
        self._top_k = top_k
        self._connection = connection
        self._llm = llm
        self._model_name = model_name
        self._vector_store_connector = vector_store_connector

    def _schema_linking(self, query: str) -> List:
        """Return summaries for every table in the connected database.

        Note:
            ``query`` is intentionally unused: this strategy always returns
            the full schema, leaving filtering to the caller (or the LLM).
        """
        # _parse_db_summary already yields one summary string per table.
        # The previous implementation wrapped each summary in a Chunk and
        # immediately read Chunk.content back out — a no-op roundtrip —
        # so return the summaries directly.
        return list(_parse_db_summary(self._connection))

    def _schema_linking_with_vector_db(self, query: str) -> List:
        """Return the ``top_k`` table schemas most similar to ``query``.

        Returns:
            List: similarity-search hits from the vector store.
        """
        # The previous implementation built a single-element query list and
        # reduce()-flattened the per-query results; with exactly one query
        # that is equivalent to a single direct search (and avoids the
        # TypeError reduce() raises on an empty sequence).
        return self._vector_store_connector.similar_search(query, self._top_k)

    async def _schema_linking_with_llm(self, query: str) -> str:
        """Ask the LLM to filter the full schema down to the tables
        relevant to ``query``.

        Returns:
            str: the LLM's text response listing the relevant schema(s).
                (Annotation fixed: the method returns ``schema[0].text``,
                a string, not a List.)
        """
        # NOTE(review): calls the public dispatcher without the leading
        # underscore — presumably defined on BaseSchemaLinker; confirm.
        chunks_content = self.schema_linking(query)
        schema_prompt = INSTRUCTION.format(
            str(chunks_content) + INPUT_PROMPT.format(query)
        )
        messages = [
            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=schema_prompt)
        ]
        request = ModelRequest(model=self._model_name, messages=messages)
        # Get accurate schema info from the llm (a single generation task).
        schema = await run_async_tasks(
            tasks=[self._llm.generate(request)], concurrency_limit=1
        )
        return schema[0].text
0 commit comments