From e1e94f997a5e7e7ac3d7ec4b8bc64042d4263417 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sat, 15 Jun 2024 14:15:58 +0800 Subject: [PATCH] fix(rag): Fix schema linking error (#1637) --- dbgpt/rag/schemalinker/schema_linking.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/dbgpt/rag/schemalinker/schema_linking.py b/dbgpt/rag/schemalinker/schema_linking.py index 4bfd3f6ed..edb0af0eb 100644 --- a/dbgpt/rag/schemalinker/schema_linking.py +++ b/dbgpt/rag/schemalinker/schema_linking.py @@ -11,9 +11,9 @@ ModelRequest, ) from dbgpt.datasource.base import BaseConnector +from dbgpt.rag.index.base import IndexStoreBase from dbgpt.rag.schemalinker.base_linker import BaseSchemaLinker from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary -from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.util.chat_util import run_async_tasks INSTRUCTION = """ @@ -46,8 +46,7 @@ def __init__( model_name: str, llm: LLMClient, top_k: int = 5, - vector_store_connector: Optional[VectorStoreConnector] = None, - **kwargs + index_store: Optional[IndexStoreBase] = None, ): """Create the schema linking instance. @@ -55,12 +54,11 @@ def __init__( connection (Optional[BaseConnector]): BaseConnector connection. llm (Optional[LLMClient]): base llm """ - super().__init__(**kwargs) self._top_k = top_k self._connector = connector self._llm = llm self._model_name = model_name - self._vector_store_connector = vector_store_connector + self._index_store = index_store def _schema_linking(self, query: str) -> List: """Get all db schema info.""" @@ -71,11 +69,10 @@ def _schema_linking(self, query: str) -> List: def _schema_linking_with_vector_db(self, query: str) -> List[Chunk]: queries = [query] - if not self._vector_store_connector: + if not self._index_store: raise ValueError("Vector store connector is not provided.") candidates = [ - self._vector_store_connector.similar_search(query, self._top_k) - for query in queries + self._index_store.similar_search(query, self._top_k) for query in queries ] return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))