From ae6187483310cb79bb375455c2836246e3a77532 Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Tue, 17 Sep 2024 18:04:20 +0800 Subject: [PATCH 01/11] feat(rdb_summary): Support summary generation and retrieval of wide tables in relational databases --- .../app/scene/chat_db/professional_qa/chat.py | 11 +- dbgpt/rag/assembler/db_schema.py | 81 +++++++++-- .../tests/test_db_struct_assembler.py | 98 ++++++------- .../tests/test_embedding_assembler.py | 28 +++- dbgpt/rag/knowledge/datasource.py | 34 +++-- dbgpt/rag/operators/db_schema.py | 22 +-- dbgpt/rag/retriever/db_schema.py | 115 ++++++++-------- dbgpt/rag/retriever/tests/test_db_struct.py | 38 +++--- dbgpt/rag/summary/db_summary_client.py | 82 +++++++++-- dbgpt/rag/summary/rdbms_db_summary.py | 129 +++++++++++++++++- dbgpt/rag/text_splitter/text_splitter.py | 39 ++++++ examples/rag/db_schema_rag_example.py | 21 +-- 12 files changed, 502 insertions(+), 196 deletions(-) diff --git a/dbgpt/app/scene/chat_db/professional_qa/chat.py b/dbgpt/app/scene/chat_db/professional_qa/chat.py index 52ed56abc..0349e8f13 100644 --- a/dbgpt/app/scene/chat_db/professional_qa/chat.py +++ b/dbgpt/app/scene/chat_db/professional_qa/chat.py @@ -55,15 +55,8 @@ async def generate_input_values(self) -> Dict: if self.db_name: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - # table_infos = client.get_db_summary( - # dbname=self.db_name, query=self.current_user_input, topk=self.top_k - # ) - table_infos = await blocking_func_to_async( - self._executor, - client.get_db_summary, - self.db_name, - self.current_user_input, - self.top_k, + table_infos = await client.aget_db_summary( + dbname=self.db_name, query=self.current_user_input, topk=self.top_k ) except Exception as e: print("db summary find error!" 
+ str(e)) diff --git a/dbgpt/rag/assembler/db_schema.py b/dbgpt/rag/assembler/db_schema.py index b55add20d..241d74e83 100644 --- a/dbgpt/rag/assembler/db_schema.py +++ b/dbgpt/rag/assembler/db_schema.py @@ -1,12 +1,13 @@ """DBSchemaAssembler.""" from typing import Any, List, Optional -from dbgpt.core import Chunk +from dbgpt.core import Chunk, Embeddings from dbgpt.datasource.base import BaseConnector +from ...serve.rag.connector import VectorStoreConnector from ..assembler.base import BaseAssembler from ..chunk_manager import ChunkParameters -from ..index.base import IndexStoreBase +from ..embedding.embedding_factory import DefaultEmbeddingFactory from ..knowledge.datasource import DatasourceKnowledge from ..retriever.db_schema import DBSchemaRetriever @@ -35,23 +36,56 @@ class DBSchemaAssembler(BaseAssembler): def __init__( self, connector: BaseConnector, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector, chunk_parameters: Optional[ChunkParameters] = None, + embedding_model: Optional[str] = None, + embeddings: Optional[Embeddings] = None, **kwargs: Any, ) -> None: """Initialize with Embedding Assembler arguments. Args: connector: (BaseConnector) BaseConnector connection. - index_store: (IndexStoreBase) IndexStoreBase to use. + table_vector_store_connector: VectorStoreConnector to load + and retrieve table info. + field_vector_store_connector: VectorStoreConnector to load + and retrieve field info. chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking. embedding_model: (Optional[str]) Embedding model to use. embeddings: (Optional[Embeddings]) Embeddings to use. 
""" - knowledge = DatasourceKnowledge(connector) self._connector = connector - self._index_store = index_store - + self._table_vector_store_connector = table_vector_store_connector + self._field_vector_store_connector = field_vector_store_connector + + self._embedding_model = embedding_model + if self._embedding_model and not embeddings: + embeddings = DefaultEmbeddingFactory( + default_model_name=self._embedding_model + ).create(self._embedding_model) + + if ( + embeddings + and self._table_vector_store_connector.vector_store_config.embedding_fn + is None + ): + self._table_vector_store_connector.vector_store_config.embedding_fn = ( + embeddings + ) + if ( + embeddings + and self._field_vector_store_connector.vector_store_config.embedding_fn + is None + ): + self._field_vector_store_connector.vector_store_config.embedding_fn = ( + embeddings + ) + max_seq_length = 512 + current_embeddings = self._table_vector_store_connector.current_embeddings + if current_embeddings: + max_seq_length = current_embeddings.client.max_seq_length # type: ignore + knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length) super().__init__( knowledge=knowledge, chunk_parameters=chunk_parameters, @@ -62,23 +96,33 @@ def __init__( def load_from_connection( cls, connector: BaseConnector, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector, chunk_parameters: Optional[ChunkParameters] = None, + embedding_model: Optional[str] = None, + embeddings: Optional[Embeddings] = None, ) -> "DBSchemaAssembler": """Load document embedding into vector store from path. Args: connector: (BaseConnector) BaseConnector connection. - index_store: (IndexStoreBase) IndexStoreBase to use. + table_vector_store_connector: used to load table chunks. + field_vector_store_connector: used to load field chunks + if field in table is too much. 
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunking. + embedding_model: (Optional[str]) Embedding model to use. + embeddings: (Optional[Embeddings]) Embeddings to use. Returns: DBSchemaAssembler """ return cls( connector=connector, - index_store=index_store, + table_vector_store_connector=table_vector_store_connector, + field_vector_store_connector=field_vector_store_connector, + embedding_model=embedding_model, chunk_parameters=chunk_parameters, + embeddings=embeddings, ) def get_chunks(self) -> List[Chunk]: @@ -91,7 +135,19 @@ def persist(self, **kwargs: Any) -> List[str]: Returns: List[str]: List of chunk ids. """ - return self._index_store.load_document(self._chunks) + table_chunks, field_chunks = [], [] + for chunk in self._chunks: + metadata = chunk.metadata + if metadata.get("separated"): + if metadata.get("part") == "table": + table_chunks.append(chunk) + else: + field_chunks.append(chunk) + else: + table_chunks.append(chunk) + + self._field_vector_store_connector.load_document(field_chunks) + return self._table_vector_store_connector.load_document(table_chunks) def _extract_info(self, chunks) -> List[Chunk]: """Extract info from chunks.""" @@ -110,5 +166,6 @@ def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever: top_k=top_k, connector=self._connector, is_embeddings=True, - index_store=self._index_store, + table_vector_store_connector=self._table_vector_store_connector, + field_vector_store_connector=self._field_vector_store_connector, ) diff --git a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py index 84638b692..f4d9cecad 100644 --- a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py +++ b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py @@ -1,76 +1,62 @@ -from unittest.mock import MagicMock +from typing import List +from unittest.mock import MagicMock, patch import pytest -from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector 
-from dbgpt.rag.assembler.embedding import EmbeddingAssembler -from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.knowledge.base import Knowledge -from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter -from dbgpt.storage.vector_store.chroma_store import ChromaStore +import dbgpt +from dbgpt.core import Chunk +from dbgpt.rag.retriever.db_schema import DBSchemaRetriever +from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary @pytest.fixture def mock_db_connection(): - """Create a temporary database connection for testing.""" - connect = SQLiteTempConnector.create_temporary_db() - connect.create_temp_tables( - { - "user": { - "columns": { - "id": "INTEGER PRIMARY KEY", - "name": "TEXT", - "age": "INTEGER", - }, - "data": [ - (1, "Tom", 10), - (2, "Jerry", 16), - (3, "Jack", 18), - (4, "Alice", 20), - (5, "Bob", 22), - ], - } - } - ) - return connect + return MagicMock() @pytest.fixture -def mock_chunk_parameters(): - return MagicMock(spec=ChunkParameters) +def mock_table_vector_store_connector(): + mock_connector = MagicMock() + mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4 + return mock_connector @pytest.fixture -def mock_embedding_factory(): - return MagicMock(spec=EmbeddingFactory) +def mock_field_vector_store_connector(): + mock_connector = MagicMock() + mock_connector.similar_search.return_value = [Chunk(content="Field summary")] * 4 + return mock_connector @pytest.fixture -def mock_vector_store_connector(): - return MagicMock(spec=ChromaStore) +def dbstruct_retriever( + mock_db_connection, + mock_table_vector_store_connector, + mock_field_vector_store_connector, +): + return DBSchemaRetriever( + connector=mock_db_connection, + table_vector_store_connector=mock_table_vector_store_connector, + field_vector_store_connector=mock_field_vector_store_connector, + ) -@pytest.fixture -def mock_knowledge(): - 
return MagicMock(spec=Knowledge) +def mock_parse_db_summary() -> str: + """Patch _parse_db_summary method.""" + return "Table summary" -def test_load_knowledge( - mock_db_connection, - mock_knowledge, - mock_chunk_parameters, - mock_embedding_factory, - mock_vector_store_connector, -): - mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" - mock_chunk_parameters.text_splitter = CharacterTextSplitter() - mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE - assembler = EmbeddingAssembler( - knowledge=mock_knowledge, - chunk_parameters=mock_chunk_parameters, - embeddings=mock_embedding_factory.create(), - index_store=mock_vector_store_connector, - ) - assembler.load_knowledge(knowledge=mock_knowledge) - assert len(assembler._chunks) == 0 +# Mocking the _parse_db_summary method in your test function +@patch.object( + dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary +) +def test_retrieve_with_mocked_summary(dbstruct_retriever): + query = "Table summary" + chunks: List[Chunk] = dbstruct_retriever._retrieve(query) + assert isinstance(chunks[0], Chunk) + assert chunks[0].content == "Table summary" + + +async def async_mock_parse_db_summary() -> str: + """Asynchronous patch for _parse_db_summary method.""" + return "Table summary" diff --git a/dbgpt/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/rag/assembler/tests/test_embedding_assembler.py index 350ccad39..2539d5ff8 100644 --- a/dbgpt/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/rag/assembler/tests/test_embedding_assembler.py @@ -6,8 +6,8 @@ from dbgpt.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory -from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter -from dbgpt.storage.vector_store.chroma_store import ChromaStore +from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter +from 
dbgpt.serve.rag.connector import VectorStoreConnector @pytest.fixture @@ -46,23 +46,37 @@ def mock_embedding_factory(): @pytest.fixture -def mock_vector_store_connector(): - return MagicMock(spec=ChromaStore) +def mock_table_vector_store_connector(): + mock_connector = MagicMock(spec=VectorStoreConnector) + mock_connector.current_embeddings.client.max_seq_length = 512 + # print(mock_connector.vector_store_config.embedding_fn.client.max_seq_length()) + return mock_connector + + +@pytest.fixture +def mock_field_vector_store_connector(): + mock_connector = MagicMock(spec=VectorStoreConnector) + mock_connector.current_embeddings.client.max_seq_length = 512 + return mock_connector def test_load_knowledge( mock_db_connection, mock_chunk_parameters, mock_embedding_factory, - mock_vector_store_connector, + mock_table_vector_store_connector, + mock_field_vector_store_connector, ): mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" - mock_chunk_parameters.text_splitter = CharacterTextSplitter() + mock_chunk_parameters.text_splitter = RDBTextSplitter( + separator="--table-field-separator--" + ) mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE assembler = DBSchemaAssembler( connector=mock_db_connection, chunk_parameters=mock_chunk_parameters, embeddings=mock_embedding_factory.create(), - index_store=mock_vector_store_connector, + table_vector_store_connector=mock_table_vector_store_connector, + field_vector_store_connector=mock_field_vector_store_connector, ) assert len(assembler._chunks) == 1 diff --git a/dbgpt/rag/knowledge/datasource.py b/dbgpt/rag/knowledge/datasource.py index 78ae045dd..504e18dc4 100644 --- a/dbgpt/rag/knowledge/datasource.py +++ b/dbgpt/rag/knowledge/datasource.py @@ -5,7 +5,7 @@ from dbgpt.datasource import BaseConnector from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary -from ..summary.rdbms_db_summary import _parse_db_summary +from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata from 
.base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType @@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge): def __init__( self, connector: BaseConnector, - summary_template: str = "{table_name}({columns})", + summary_template: str = "table_name: {table_name}", + separator: str = "--table-field-separator--", knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, metadata: Optional[Dict[str, Union[str, List[str]]]] = None, + model_dimension: int = 512, **kwargs: Any, ) -> None: """Create Datasource Knowledge with Knowledge arguments. @@ -25,11 +27,17 @@ def __init__( Args: connector(BaseConnector): connector summary_template(str, optional): summary template + separator(str, optional): separator used to separate + table's basic info and fields. + defaults `-- table-field-separator--` knowledge_type(KnowledgeType, optional): knowledge type metadata(Dict[str, Union[str, List[str]], optional): metadata + model_dimension(int, optional): The threshold for splitting field string """ + self._separator = separator self._connector = connector self._summary_template = summary_template + self._model_dimension = model_dimension super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs) def _load(self) -> List[Document]: @@ -37,13 +45,23 @@ def _load(self) -> List[Document]: docs = [] if self._connector.is_graph_type(): db_summary = _parse_gdb_summary(self._connector, self._summary_template) + for table_summary in db_summary: + metadata = {"source": "database"} + docs.append(Document(content=table_summary, metadata=metadata)) else: - db_summary = _parse_db_summary(self._connector, self._summary_template) - for table_summary in db_summary: - metadata = {"source": "database"} - if self._metadata: - metadata.update(self._metadata) # type: ignore - docs.append(Document(content=table_summary, metadata=metadata)) + db_summary_with_metadata = _parse_db_summary_with_metadata( + self._connector, + self._summary_template, + self._separator, + 
self._model_dimension, + ) + for summary, table_metadata in db_summary_with_metadata: + metadata = {"source": "database"} + + if self._metadata: + metadata.update(self._metadata) # type: ignore + table_metadata.update(metadata) + docs.append(Document(content=summary, metadata=table_metadata)) return docs @classmethod diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index d0a7c0d9f..8ec2d5b60 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -2,13 +2,14 @@ from typing import List, Optional +from dbgpt.serve.rag.connector import VectorStoreConnector + from dbgpt.core import Chunk from dbgpt.core.interface.operators.retriever import RetrieverOperator from dbgpt.datasource.base import BaseConnector from ..assembler.db_schema import DBSchemaAssembler from ..chunk_manager import ChunkParameters -from ..index.base import IndexStoreBase from ..retriever.db_schema import DBSchemaRetriever from .assembler import AssemblerOperator @@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]): Args: connector (BaseConnector): The connection. top_k (int, optional): The top k. Defaults to 4. - index_store (IndexStoreBase, optional): The vector store + vector_store_connector (VectorStoreConnector, optional): The vector store connector. Defaults to None. 
""" def __init__( self, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector, top_k: int = 4, connector: Optional[BaseConnector] = None, **kwargs @@ -35,7 +37,8 @@ def __init__( self._retriever = DBSchemaRetriever( top_k=top_k, connector=connector, - index_store=index_store, + table_vector_store_connector=table_vector_store_connector, + field_vector_store_connector=field_vector_store_connector, ) def retrieve(self, query: str) -> List[Chunk]: @@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]): def __init__( self, connector: BaseConnector, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector, chunk_parameters: Optional[ChunkParameters] = None, **kwargs ): @@ -61,14 +65,15 @@ def __init__( Args: connector (BaseConnector): The connection. - index_store (IndexStoreBase): The Storage IndexStoreBase. + vector_store_connector (VectorStoreConnector): The vector store connector. chunk_parameters (Optional[ChunkParameters], optional): The chunk parameters. 
""" if not chunk_parameters: chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") self._chunk_parameters = chunk_parameters - self._index_store = index_store + self._table_vector_store_connector = table_vector_store_connector + self._field_vector_store_connector = field_vector_store_connector self._connector = connector super().__init__(**kwargs) @@ -84,7 +89,8 @@ def assemble(self, dummy_value) -> List[Chunk]: assembler = DBSchemaAssembler.load_from_connection( connector=self._connector, chunk_parameters=self._chunk_parameters, - index_store=self._index_store, + table_vector_store_connector=self._table_vector_store_connector, + field_vector_store_connector=self._field_vector_store_connector, ) assembler.persist() return assembler.get_chunks() diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index 9bced9267..bcb13015b 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,18 +1,15 @@ """DBSchema retriever.""" - from functools import reduce from typing import List, Optional, cast from dbgpt.core import Chunk from dbgpt.datasource.base import BaseConnector -from dbgpt.rag.index.base import IndexStoreBase from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary -from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.serve.rag.connector import VectorStoreConnector +from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters from dbgpt.util.chat_util import run_async_tasks -from dbgpt.util.executor_utils import blocking_func_to_async_no_executor -from dbgpt.util.tracer import root_tracer class DBSchemaRetriever(BaseRetriever): @@ -20,7 +17,9 @@ class DBSchemaRetriever(BaseRetriever): def __init__( self, - index_store: IndexStoreBase, + table_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: 
VectorStoreConnector, + separator: str = "--table-field-separator--", top_k: int = 4, connector: Optional[BaseConnector] = None, query_rewrite: bool = False, @@ -30,7 +29,11 @@ def __init__( """Create DBSchemaRetriever. Args: - index_store(IndexStore): index connector + table_vector_store_connector: VectorStoreConnector + to load and retrieve table info. + field_vector_store_connector: VectorStoreConnector + to load and retrieve field info. + separator: field/table separator top_k (int): top k connector (Optional[BaseConnector]): RDBMSConnector. query_rewrite (bool): query rewrite @@ -70,34 +73,32 @@ def _create_temporary_connection(): connector = _create_temporary_connection() + vector_store_config = ChromaVectorConfig(name="vector_store_name") + embedding_model_path = "{your_embedding_model_path}" embedding_fn = embedding_factory.create(model_name=embedding_model_path) - config = ChromaVectorConfig( - persist_path=PILOT_PATH, - name="dbschema_rag_test", - embedding_fn=DefaultEmbeddingFactory( - default_model_name=os.path.join( - MODEL_PATH, "text2vec-large-chinese" - ), - ).create(), + vector_connector = VectorStoreConnector.from_default( + "Chroma", + vector_store_config=vector_store_config, + embedding_fn=embedding_fn, ) - - vector_store = ChromaStore(config) # get db struct retriever retriever = DBSchemaRetriever( top_k=3, - index_store=vector_store, + vector_store_connector=vector_connector, connector=connector, ) chunks = retriever.retrieve("show columns from table") result = [chunk.content for chunk in chunks] print(f"db struct rag example results:{result}") """ + self._separator = separator self._top_k = top_k self._connector = connector self._query_rewrite = query_rewrite - self._index_store = index_store + self._table_vector_store_connector = table_vector_store_connector + self._field_vector_store_connector = field_vector_store_connector self._need_embeddings = False - if self._index_store: + if self._table_vector_store_connector: self._need_embeddings = 
True self._rerank = rerank or DefaultRanker(self._top_k) @@ -116,7 +117,9 @@ def _retrieve( if self._need_embeddings: queries = [query] candidates = [ - self._index_store.similar_search(query, self._top_k, filters) + self._table_vector_store_connector.similar_search( + query, self._top_k, filters + ) for query in queries ] return cast(List[Chunk], reduce(lambda x, y: x + y, candidates)) @@ -157,29 +160,20 @@ async def _aretrieve( List[Chunk]: list of chunks """ if self._need_embeddings: - queries = [query] - candidates = [ - self._similarity_search( - query, filters, root_tracer.get_current_span_id() - ) - for query in queries - ] + candidates = [self._similarity_search(query, filters)] result_candidates = await run_async_tasks( - tasks=candidates, concurrency_limit=1 + tasks=candidates, concurrency_limit=3 ) - return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates)) + return result_candidates[0] else: from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 _parse_db_summary, ) table_summaries = await run_async_tasks( - tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())], - concurrency_limit=1, + tasks=[self._aparse_db_summary()], concurrency_limit=1 ) - return [ - Chunk(content=table_summary) for table_summary in table_summaries[0] - ] + return [Chunk(content=table_summary) for table_summary in table_summaries] async def _aretrieve_with_score( self, @@ -196,34 +190,41 @@ async def _aretrieve_with_score( """ return await self._aretrieve(query, filters) + async def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk: + metadata = table_chunk.metadata + metadata["part"] = "field" + filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()] + field_chunks = self._field_vector_store_connector.similar_search( + query, self._top_k, MetadataFilters(filters=filters) + ) + field_contents = [chunk.content for chunk in field_chunks] + table_chunk.content += self._separator + "\n".join(field_contents) + return table_chunk 
+ async def _similarity_search( - self, - query, - filters: Optional[MetadataFilters] = None, - parent_span_id: Optional[str] = None, + self, query, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: """Similar search.""" - with root_tracer.start_span( - "dbgpt.rag.retriever.db_schema._similarity_search", - parent_span_id, - metadata={"query": query}, - ): - return await blocking_func_to_async_no_executor( - self._index_store.similar_search, query, self._top_k, filters + table_chunks = ( + await self._table_vector_store_connector.asimilar_search_with_scores( + query, self._top_k, 0, filters ) - - async def _aparse_db_summary( - self, parent_span_id: Optional[str] = None - ) -> List[str]: + ) + not_sep_chunks = [ + chunk for chunk in table_chunks if not chunk.metadata.get("separated") + ] + separated_chunks = [ + chunk for chunk in table_chunks if chunk.metadata.get("separated") + ] + separated_result = await run_async_tasks( + tasks=[self._retrieve_field(chunk, query) for chunk in separated_chunks] + ) + return not_sep_chunks + separated_result + + async def _aparse_db_summary(self) -> List[str]: """Similar search.""" from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary if not self._connector: raise RuntimeError("RDBMSConnector connection is required.") - with root_tracer.start_span( - "dbgpt.rag.retriever.db_schema._aparse_db_summary", - parent_span_id, - ): - return await blocking_func_to_async_no_executor( - _parse_db_summary, self._connector - ) + return _parse_db_summary(self._connector) diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index 0b667f69e..f4d9cecad 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -15,42 +15,48 @@ def mock_db_connection(): @pytest.fixture -def mock_vector_store_connector(): +def mock_table_vector_store_connector(): mock_connector = MagicMock() mock_connector.similar_search.return_value = 
[Chunk(content="Table summary")] * 4 return mock_connector @pytest.fixture -def db_struct_retriever(mock_db_connection, mock_vector_store_connector): +def mock_field_vector_store_connector(): + mock_connector = MagicMock() + mock_connector.similar_search.return_value = [Chunk(content="Field summary")] * 4 + return mock_connector + + +@pytest.fixture +def dbstruct_retriever( + mock_db_connection, + mock_table_vector_store_connector, + mock_field_vector_store_connector, +): return DBSchemaRetriever( connector=mock_db_connection, - index_store=mock_vector_store_connector, + table_vector_store_connector=mock_table_vector_store_connector, + field_vector_store_connector=mock_field_vector_store_connector, ) -def mock_parse_db_summary(conn) -> List[str]: +def mock_parse_db_summary() -> str: """Patch _parse_db_summary method.""" - return ["Table summary"] + return "Table summary" # Mocking the _parse_db_summary method in your test function @patch.object( dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary ) -def test_retrieve_with_mocked_summary(db_struct_retriever): +def test_retrieve_with_mocked_summary(dbstruct_retriever): query = "Table summary" - chunks: List[Chunk] = db_struct_retriever._retrieve(query) + chunks: List[Chunk] = dbstruct_retriever._retrieve(query) assert isinstance(chunks[0], Chunk) assert chunks[0].content == "Table summary" -@pytest.mark.asyncio -@patch.object( - dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary -) -async def test_aretrieve_with_mocked_summary(db_struct_retriever): - query = "Table summary" - chunks: List[Chunk] = await db_struct_retriever._aretrieve(query) - assert isinstance(chunks[0], Chunk) - assert chunks[0].content == "Table summary" +async def async_mock_parse_db_summary() -> str: + """Asynchronous patch for _parse_db_summary method.""" + return "Table summary" diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 
8ce9a79e6..158d0c359 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -4,11 +4,15 @@ import traceback from typing import List +from dbgpt.serve.rag.connector import VectorStoreConnector + from dbgpt._private.config import Config from dbgpt.component import SystemApp from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG +from dbgpt.rag import ChunkParameters from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary +from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter logger = logging.getLogger(__name__) @@ -47,26 +51,66 @@ def db_summary_embedding(self, dbname, db_type): logger.info("db summary embedding success") - def get_db_summary(self, dbname, query, topk) -> List[str]: + def get_db_summary(self, dbname, query, topk): """Get user query related tables info.""" - from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - vector_store_config = VectorStoreConfig(name=dbname + "_profile") - vector_connector = VectorStoreConnector.from_default( + vector_store_name = dbname + "_profile" + table_vector_store_config = VectorStoreConfig(name=vector_store_name + "_table") + field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") + table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, - embedding_fn=self.embeddings, - vector_store_config=vector_store_config, + self.embeddings, + vector_store_config=table_vector_store_config, + ) + field_vector_connector = VectorStoreConnector.from_default( + CFG.VECTOR_STORE_TYPE, + self.embeddings, + vector_store_config=field_vector_store_config, ) from dbgpt.rag.retriever.db_schema import DBSchemaRetriever retriever = DBSchemaRetriever( - top_k=topk, index_store=vector_connector.index_client + top_k=topk, + table_vector_store_connector=table_vector_connector, + 
field_vector_store_connector=field_vector_connector, + separator="--table-field-separator--", ) + table_docs = retriever.retrieve(query) ans = [d.content for d in table_docs] return ans + async def aget_db_summary(self, dbname, query, topk): + """Get user query related tables info.""" + from dbgpt.serve.rag.connector import VectorStoreConnector + from dbgpt.storage.vector_store.base import VectorStoreConfig + + vector_store_name = dbname + "_profile" + table_vector_store_config = VectorStoreConfig(name=vector_store_name + "_table") + field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") + table_vector_connector = VectorStoreConnector.from_default( + CFG.VECTOR_STORE_TYPE, + self.embeddings, + vector_store_config=table_vector_store_config, + ) + field_vector_connector = VectorStoreConnector.from_default( + CFG.VECTOR_STORE_TYPE, + self.embeddings, + vector_store_config=field_vector_store_config, + ) + from dbgpt.rag.retriever.db_schema import DBSchemaRetriever + + retriever = DBSchemaRetriever( + top_k=topk, + table_vector_store_connector=table_vector_connector, + field_vector_store_connector=field_vector_connector, + ) + + table_docs = await retriever.aretrieve(query) + ans = [d.content for d in table_docs] + return ans + def init_db_summary(self): """Initialize db summary profile.""" db_mange = CFG.local_db_manager @@ -92,18 +136,32 @@ def init_db_profile(self, db_summary_client, dbname): from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - vector_store_config = VectorStoreConfig(name=vector_store_name) - vector_connector = VectorStoreConnector.from_default( + table_vector_store_config = VectorStoreConfig(name=vector_store_name + "_table") + field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") + table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, self.embeddings, - vector_store_config=vector_store_config, + 
vector_store_config=table_vector_store_config, ) - if not vector_connector.vector_name_exists(): + field_vector_connector = VectorStoreConnector.from_default( + CFG.VECTOR_STORE_TYPE, + self.embeddings, + vector_store_config=field_vector_store_config, + ) + if ( + not table_vector_connector.vector_name_exists() + or not field_vector_connector.vector_name_exists() + ): from dbgpt.rag.assembler.db_schema import DBSchemaAssembler + chunk_parameters = ChunkParameters( + text_splitter=RDBTextSplitter(separator="--table-field-separator--") + ) db_assembler = DBSchemaAssembler.load_from_connection( connector=db_summary_client.db, - index_store=vector_connector.index_client, + table_vector_store_connector=table_vector_connector, + field_vector_store_connector=field_vector_connector, + chunk_parameters=chunk_parameters, ) if len(db_assembler.get_chunks()) > 0: diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index 337d3851b..c0bd35972 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -1,6 +1,6 @@ """Summary for rdbms database.""" import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from dbgpt._private.config import Config from dbgpt.datasource import BaseConnector @@ -80,6 +80,133 @@ def _parse_db_summary( return table_info_summaries +def _parse_db_summary_with_metadata( + conn: BaseConnector, + summary_template: str = "table_name: {table_name}", + separator: str = "--table-field-separator--", + model_dimension: int = 512, +) -> List[Tuple[str, Dict[str, Any]]]: + """Get db summary for database. + + Args: + conn (BaseConnector): database connection + summary_template (str): summary template + separator(str, optional): separator used to separate table's + basic info and fields. 
defaults to `--table-field-separator--`
+        model_dimension(int, optional): The threshold for splitting field string
+    """
+    tables = conn.get_table_names()
+    table_info_summaries = [
+        _parse_table_summary_with_metadata(
+            conn, summary_template, separator, table_name, model_dimension
+        )
+        for table_name in tables
+    ]
+    return table_info_summaries
+
+
+def _split_columns_str(columns: List[str], model_dimension: int):
+    """Split columns str.
+
+    Args:
+        columns (List[str]): fields string
+        model_dimension (int, optional): The threshold for splitting field string.
+    """
+    result = []
+    current_string = ""
+    current_length = 0
+
+    for element_str in columns:
+        element_length = len(element_str)
+
+        # 如果加上当前元素的长度会超过阈值,则将当前字符串添加到结果中,并重置
+        if current_length + element_length > model_dimension:
+            result.append(current_string.strip())  # 去掉末尾的空格
+            current_string = element_str
+            current_length = element_length
+        else:
+            # 如果当前字符串为空,直接添加元素
+            if current_string:
+                current_string += "," + element_str
+            else:
+                current_string = element_str
+            current_length += element_length + 1  # 加上空格的长度
+
+    # 最后一段字符串
+    if current_string:
+        result.append(current_string.strip())
+
+    return result
+
+
+def _parse_table_summary_with_metadata(
+    conn: BaseConnector,
+    summary_template: str,
+    separator,
+    table_name: str,
+    model_dimension=512,
+) -> Tuple[str, Dict[str, Any]]:
+    """Get table summary for table.
+
+    Args:
+        conn (BaseConnector): database connection
+        summary_template (str): summary template
+        separator(str, optional): separator used to separate table's
+            basic info and fields. 
defaults to `--table-field-separator--`
+        model_dimension(int, optional): The threshold for splitting field string
+
+    Examples:
+        metadata: {'table_name': 'asd', 'separated': 0/1}
+
+        table_name: table1
+        table_comment: comment
+        index_keys: keys
+        --table-field-separator--
+        (column1,comment), (column2, comment), (column3, comment)
+        (column4,comment), (column5, comment), (column6, comment)
+    """
+    columns = []
+    metadata = {"table_name": table_name, "separated": 0}
+    for column in conn.get_columns(table_name):
+        if column.get("comment"):
+            columns.append(f"{column['name']} ({column.get('comment')})")
+        else:
+            columns.append(f"{column['name']}")
+    metadata.update({"field_num": len(columns)})
+    separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
+    if len(separated_columns) > 1:
+        metadata["separated"] = 1
+    column_str = "\n".join(separated_columns)
+    # Obtain index information
+    index_keys = []
+    raw_indexes = conn.get_indexes(table_name)
+    for index in raw_indexes:
+        if isinstance(index, tuple):  # Process tuple type index information
+            index_name, index_creation_command = index
+            # Extract column names using re
+            matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command)
+            if matched_columns:
+                key_str = ", ".join(matched_columns)
+                index_keys.append(f"{index_name}(`{key_str}`) ")
+        else:
+            key_str = ", ".join(index["column_names"])
+            index_keys.append(f"{index['name']}(`{key_str}`) ")
+    table_str = summary_template.format(table_name=table_name)
+
+    try:
+        comment = conn.get_table_comment(table_name)
+    except Exception:
+        comment = dict(text=None)
+    if comment.get("text"):
+        table_str += f"\ntable_comment: {comment.get('text')}"
+
+    if len(index_keys) > 0:
+        index_key_str = ", ".join(index_keys)
+        table_str += f"\nindex_keys: {index_key_str}"
+    table_str += f"\n{separator}\n{column_str}"
+    return table_str, metadata
+
+
 def _parse_table_summary(
     conn: BaseConnector, summary_template: str, table_name: str
 ) -> str:
diff 
--git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py index a4b44ca8e..c2e79ad33 100644 --- a/dbgpt/rag/text_splitter/text_splitter.py +++ b/dbgpt/rag/text_splitter/text_splitter.py @@ -902,3 +902,42 @@ def create_documents( new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i])) chunks.append(new_doc) return chunks + + +class RDBTextSplitter(TextSplitter): + """Split relational database tables and fields.""" + + def __init__(self, **kwargs): + """Create a new TextSplitter.""" + super().__init__(**kwargs) + + def split_text(self, text: str, **kwargs): + """Split text into a couple of parts.""" + pass + + def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]: + """Split document into chunks.""" + chunks = [] + for doc in documents: + metadata = doc.metadata + content = doc.content + if metadata.get("separated"): + # separate table and field + parts = content.split(self._separator) + table_part, field_part = parts[0], parts[1] + table_metadata, field_metadata = copy.deepcopy(metadata), copy.deepcopy( + metadata + ) + table_metadata["part"] = "table" # identify of table_chunk + field_metadata["part"] = "field" # identify of field_chunk + table_chunk = Chunk(content=table_part, metadata=table_metadata) + chunks.append(table_chunk) + field_parts = field_part.split("\n") + for i, sub_part in enumerate(field_parts): + sub_metadata = copy.deepcopy(field_metadata) + sub_metadata["part_index"] = i + field_chunk = Chunk(content=sub_part, metadata=sub_metadata) + chunks.append(field_chunk) + else: + chunks.append(Chunk(content=content, metadata=metadata)) + return chunks diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index 1524634fa..7fe1ec15e 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -4,7 +4,8 @@ from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.assembler import 
DBSchemaAssembler from dbgpt.rag.embedding import DefaultEmbeddingFactory -from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector """DB struct rag example. pre-requirements: @@ -12,7 +13,7 @@ ``` embedding_model_path = "{your_embedding_model_path}" ``` - + Examples: ..code-block:: shell python examples/rag/db_schema_rag_example.py @@ -45,27 +46,27 @@ def _create_temporary_connection(): def _create_vector_connector(): """Create vector connector.""" - config = ChromaVectorConfig( - persist_path=PILOT_PATH, - name="dbschema_rag_test", + return VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name="db_schema_vector_store_name", + persist_path=os.path.join(PILOT_PATH, "data"), + ), embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - if __name__ == "__main__": connection = _create_temporary_connection() - index_store = _create_vector_connector() + vector_connector = _create_vector_connector() assembler = DBSchemaAssembler.load_from_connection( connector=connection, - index_store=index_store, + vector_store_connector=vector_connector, ) assembler.persist() # get db schema retriever retriever = assembler.as_retriever(top_k=1) chunks = retriever.retrieve("show columns from user") print(f"db schema rag example results:{[chunk.content for chunk in chunks]}") - index_store.delete_vector_name("dbschema_rag_test") From 6f948e12e991d03c86dc94d1ed66a58514a60808 Mon Sep 17 00:00:00 2001 From: dong Date: Sun, 22 Sep 2024 17:47:47 +0800 Subject: [PATCH 02/11] feat:unit test for wide table summary and retrival --- .../tests/test_db_struct_assembler.py | 71 ++++++++++++++++--- .../tests/test_embedding_assembler.py | 31 ++++---- dbgpt/rag/operators/db_schema.py | 3 +- 
dbgpt/rag/retriever/db_schema.py | 8 ++- dbgpt/rag/summary/db_summary_client.py | 4 +- 5 files changed, 89 insertions(+), 28 deletions(-) diff --git a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py index f4d9cecad..6f8f59cf1 100644 --- a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py +++ b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py @@ -1,5 +1,5 @@ from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,14 +17,55 @@ def mock_db_connection(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4 + chunk = Chunk( + content="table_name: user\ncomment: user about dbgpt", + metadata={ + "field_num": 6, + "part": "table", + "separated": 1, + "table_name": "user", + }, + ) + mock_connector.asimilar_search_with_scores = AsyncMock(return_value=[chunk]) return mock_connector @pytest.fixture def mock_field_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search.return_value = [Chunk(content="Field summary")] * 4 + chunk1 = Chunk( + content="name,age", + metadata={ + "field_num": 6, + "part": "field", + "part_index": 0, + "separated": 1, + "table_name": "user", + }, + ) + chunk2 = Chunk( + content="address,gender", + metadata={ + "field_num": 6, + "part": "field", + "part_index": 1, + "separated": 1, + "table_name": "user", + }, + ) + chunk3 = Chunk( + content="mail,phone", + metadata={ + "field_num": 6, + "part": "field", + "part_index": 2, + "separated": 1, + "table_name": "user", + }, + ) + mock_connector.asimilar_search_with_scores = AsyncMock( + return_value=[chunk1, chunk2, chunk3] + ) return mock_connector @@ -38,25 +79,39 @@ def dbstruct_retriever( connector=mock_db_connection, table_vector_store_connector=mock_table_vector_store_connector, 
field_vector_store_connector=mock_field_vector_store_connector, + separator="--table-field-separator--", ) def mock_parse_db_summary() -> str: """Patch _parse_db_summary method.""" - return "Table summary" + return ( + "table_name: user\ncomment: user about dbgpt\n" + "--table-field-separator--\n" + "name,age\naddress,gender\nmail,phone" + ) # Mocking the _parse_db_summary method in your test function @patch.object( dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary ) -def test_retrieve_with_mocked_summary(dbstruct_retriever): +@pytest.mark.asyncio +async def test_retrieve_with_mocked_summary(dbstruct_retriever): query = "Table summary" - chunks: List[Chunk] = dbstruct_retriever._retrieve(query) + chunks: List[Chunk] = await dbstruct_retriever._aretrieve(query) assert isinstance(chunks[0], Chunk) - assert chunks[0].content == "Table summary" + assert chunks[0].content == ( + "table_name: user\ncomment: user about dbgpt\n" + "--table-field-separator--\n" + "name,age\naddress,gender\nmail,phone" + ) async def async_mock_parse_db_summary() -> str: """Asynchronous patch for _parse_db_summary method.""" - return "Table summary" + return ( + "table_name: user\ncomment: user about dbgpt\n" + "--table-field-separator--\n" + "name,age\naddress,gender\nmail,phone" + ) diff --git a/dbgpt/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/rag/assembler/tests/test_embedding_assembler.py index 2539d5ff8..10d5f04d1 100644 --- a/dbgpt/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/rag/assembler/tests/test_embedding_assembler.py @@ -21,14 +21,22 @@ def mock_db_connection(): "id": "INTEGER PRIMARY KEY", "name": "TEXT", "age": "INTEGER", - }, - "data": [ - (1, "Tom", 10), - (2, "Jerry", 16), - (3, "Jack", 18), - (4, "Alice", 20), - (5, "Bob", 22), - ], + "address": "TEXT", + "phone": "TEXT", + "email": "TEXT", + "gender": "TEXT", + "birthdate": "TEXT", + "occupation": "TEXT", + "education": "TEXT", + "marital_status": "TEXT", + 
"nationality": "TEXT", + "height": "REAL", + "weight": "REAL", + "blood_type": "TEXT", + "emergency_contact": "TEXT", + "created_at": "TEXT", + "updated_at": "TEXT", + } } } ) @@ -48,15 +56,14 @@ def mock_embedding_factory(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock(spec=VectorStoreConnector) - mock_connector.current_embeddings.client.max_seq_length = 512 - # print(mock_connector.vector_store_config.embedding_fn.client.max_seq_length()) + mock_connector.current_embeddings.client.max_seq_length = 10 return mock_connector @pytest.fixture def mock_field_vector_store_connector(): mock_connector = MagicMock(spec=VectorStoreConnector) - mock_connector.current_embeddings.client.max_seq_length = 512 + mock_connector.current_embeddings.client.max_seq_length = 10 return mock_connector @@ -79,4 +86,4 @@ def test_load_knowledge( table_vector_store_connector=mock_table_vector_store_connector, field_vector_store_connector=mock_field_vector_store_connector, ) - assert len(assembler._chunks) == 1 + assert len(assembler._chunks) > 1 diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index 8ec2d5b60..15c6986d6 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -2,11 +2,10 @@ from typing import List, Optional -from dbgpt.serve.rag.connector import VectorStoreConnector - from dbgpt.core import Chunk from dbgpt.core.interface.operators.retriever import RetrieverOperator from dbgpt.datasource.base import BaseConnector +from dbgpt.serve.rag.connector import VectorStoreConnector from ..assembler.db_schema import DBSchemaAssembler from ..chunk_manager import ChunkParameters diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index bcb13015b..2b2700c75 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -194,11 +194,13 @@ async def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk: metadata = 
table_chunk.metadata metadata["part"] = "field" filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()] - field_chunks = self._field_vector_store_connector.similar_search( - query, self._top_k, MetadataFilters(filters=filters) + field_chunks = ( + await self._field_vector_store_connector.asimilar_search_with_scores( + query, self._top_k, 0, MetadataFilters(filters=filters) + ) ) field_contents = [chunk.content for chunk in field_chunks] - table_chunk.content += self._separator + "\n".join(field_contents) + table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents) return table_chunk async def _similarity_search( diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 158d0c359..889c9f1fc 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -2,9 +2,6 @@ import logging import traceback -from typing import List - -from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt._private.config import Config from dbgpt.component import SystemApp @@ -13,6 +10,7 @@ from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter +from dbgpt.serve.rag.connector import VectorStoreConnector logger = logging.getLogger(__name__) From c4f2bd942b07eaf5cb41309b1a05b181845df2b3 Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Thu, 12 Dec 2024 14:29:20 +0800 Subject: [PATCH 03/11] feat(rdbsummary): Ensure that old data is available --- dbgpt/rag/summary/db_summary_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 889c9f1fc..7c4bc5595 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -54,7 +54,7 @@ def get_db_summary(self, dbname, query, topk): from 
dbgpt.storage.vector_store.base import VectorStoreConfig vector_store_name = dbname + "_profile" - table_vector_store_config = VectorStoreConfig(name=vector_store_name + "_table") + table_vector_store_config = VectorStoreConfig(name=vector_store_name) field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, @@ -85,7 +85,7 @@ async def aget_db_summary(self, dbname, query, topk): from dbgpt.storage.vector_store.base import VectorStoreConfig vector_store_name = dbname + "_profile" - table_vector_store_config = VectorStoreConfig(name=vector_store_name + "_table") + table_vector_store_config = VectorStoreConfig(name=vector_store_name) field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, @@ -134,7 +134,7 @@ def init_db_profile(self, db_summary_client, dbname): from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - table_vector_store_config = VectorStoreConfig(name=vector_store_name + "_table") + table_vector_store_config = VectorStoreConfig(name=vector_store_name) field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, From b536a994e10317890236b312166ce5bdd1f8b61f Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Thu, 12 Dec 2024 14:59:44 +0800 Subject: [PATCH 04/11] feat(rdbsummary): Ensure that old data is available --- dbgpt/rag/summary/db_summary_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 7c4bc5595..32febb4b8 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -146,10 +146,7 @@ def init_db_profile(self, 
db_summary_client, dbname): self.embeddings, vector_store_config=field_vector_store_config, ) - if ( - not table_vector_connector.vector_name_exists() - or not field_vector_connector.vector_name_exists() - ): + if not table_vector_connector.vector_name_exists(): from dbgpt.rag.assembler.db_schema import DBSchemaAssembler chunk_parameters = ChunkParameters( From d78acc1f1ce362cc59522ef3b4384b97f06e2d95 Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Thu, 12 Dec 2024 17:53:00 +0800 Subject: [PATCH 05/11] feat(rdbsummary-wide-table): surport chatdashboard --- .env.template | 1 + dbgpt/_private/config.py | 1 + dbgpt/app/scene/chat_dashboard/chat.py | 8 +------- dbgpt/app/scene/chat_db/auto_execute/chat.py | 8 +------- dbgpt/rag/assembler/db_schema.py | 9 +++++---- dbgpt/rag/assembler/tests/test_embedding_assembler.py | 3 +-- dbgpt/rag/summary/db_summary_client.py | 1 + scripts/examples/load_examples.sh | 7 +++++++ 8 files changed, 18 insertions(+), 20 deletions(-) diff --git a/.env.template b/.env.template index cb96c9dcb..f8e3aaed8 100644 --- a/.env.template +++ b/.env.template @@ -66,6 +66,7 @@ QUANTIZE_8bit=True #** EMBEDDING SETTINGS **# #*******************************************************************# EMBEDDING_MODEL=text2vec +EMBEDDING_MODEL_MAX_SEQ_LEN=512 #EMBEDDING_MODEL=m3e-large #EMBEDDING_MODEL=bge-large-en #EMBEDDING_MODEL=bge-large-zh diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 14496df4f..e6e124871 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -264,6 +264,7 @@ def __init__(self) -> None: # EMBEDDING Configuration self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") + self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512)) # Rerank model configuration self.RERANK_MODEL = os.getenv("RERANK_MODEL") self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH") diff --git a/dbgpt/app/scene/chat_dashboard/chat.py b/dbgpt/app/scene/chat_dashboard/chat.py index 
bb35de7cc..25962aadd 100644 --- a/dbgpt/app/scene/chat_dashboard/chat.py +++ b/dbgpt/app/scene/chat_dashboard/chat.py @@ -61,13 +61,7 @@ async def generate_input_values(self) -> Dict: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - table_infos = await blocking_func_to_async( - self._executor, - client.get_db_summary, - self.db_name, - self.current_user_input, - self.top_k, - ) + table_infos = await client.aget_db_summary(self.db_name, self.current_user_input, self.top_k) print("dashboard vector find tables:{}", table_infos) except Exception as e: print("db summary find error!" + str(e)) diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py index 69b3d493b..528180bd8 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/chat.py +++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py @@ -55,13 +55,7 @@ async def generate_input_values(self) -> Dict: table_infos = None try: with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): - table_infos = await blocking_func_to_async( - self._executor, - client.get_db_summary, - self.db_name, - self.current_user_input, - CFG.KNOWLEDGE_SEARCH_TOP_SIZE, - ) + table_infos = await client.aget_db_summary(self.db_name, self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE) except Exception as e: print("db summary find error!" + str(e)) if not table_infos: diff --git a/dbgpt/rag/assembler/db_schema.py b/dbgpt/rag/assembler/db_schema.py index 241d74e83..cd71f470d 100644 --- a/dbgpt/rag/assembler/db_schema.py +++ b/dbgpt/rag/assembler/db_schema.py @@ -1,4 +1,5 @@ """DBSchemaAssembler.""" +import os from typing import Any, List, Optional from dbgpt.core import Chunk, Embeddings @@ -41,6 +42,7 @@ def __init__( chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embeddings: Optional[Embeddings] = None, + max_seq_length: int = 512, **kwargs: Any, ) -> None: """Initialize with Embedding Assembler arguments. 
@@ -81,10 +83,6 @@ def __init__( self._field_vector_store_connector.vector_store_config.embedding_fn = ( embeddings ) - max_seq_length = 512 - current_embeddings = self._table_vector_store_connector.current_embeddings - if current_embeddings: - max_seq_length = current_embeddings.client.max_seq_length # type: ignore knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length) super().__init__( knowledge=knowledge, @@ -101,6 +99,7 @@ def load_from_connection( chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embeddings: Optional[Embeddings] = None, + max_seq_length: int = 512, ) -> "DBSchemaAssembler": """Load document embedding into vector store from path. @@ -113,6 +112,7 @@ def load_from_connection( chunking. embedding_model: (Optional[str]) Embedding model to use. embeddings: (Optional[Embeddings]) Embeddings to use. + max_seq_length: Embedding model max sequence length Returns: DBSchemaAssembler """ @@ -123,6 +123,7 @@ def load_from_connection( embedding_model=embedding_model, chunk_parameters=chunk_parameters, embeddings=embeddings, + max_seq_length=max_seq_length, ) def get_chunks(self) -> List[Chunk]: diff --git a/dbgpt/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/rag/assembler/tests/test_embedding_assembler.py index 10d5f04d1..a2e3c9679 100644 --- a/dbgpt/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/rag/assembler/tests/test_embedding_assembler.py @@ -56,14 +56,12 @@ def mock_embedding_factory(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock(spec=VectorStoreConnector) - mock_connector.current_embeddings.client.max_seq_length = 10 return mock_connector @pytest.fixture def mock_field_vector_store_connector(): mock_connector = MagicMock(spec=VectorStoreConnector) - mock_connector.current_embeddings.client.max_seq_length = 10 return mock_connector @@ -85,5 +83,6 @@ def test_load_knowledge( embeddings=mock_embedding_factory.create(), 
table_vector_store_connector=mock_table_vector_store_connector, field_vector_store_connector=mock_field_vector_store_connector, + max_seq_length=10 ) assert len(assembler._chunks) > 1 diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 32febb4b8..23cd23bb9 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -157,6 +157,7 @@ def init_db_profile(self, db_summary_client, dbname): table_vector_store_connector=table_vector_connector, field_vector_store_connector=field_vector_connector, chunk_parameters=chunk_parameters, + max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN ) if len(db_assembler.get_chunks()) > 0: diff --git a/scripts/examples/load_examples.sh b/scripts/examples/load_examples.sh index 0a829bdac..b01dedba3 100755 --- a/scripts/examples/load_examples.sh +++ b/scripts/examples/load_examples.sh @@ -15,6 +15,7 @@ fi DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db" DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql" DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db" +WIDE_DB_FILE="$WORK_DIR/pilot/data/wide_sqlite.db" SQL_FILE="" usage () { @@ -61,6 +62,12 @@ if [ -n $SQL_FILE ];then sqlite3 $DB_FILE < "$file" done + for file in $WORK_DIR/docker/examples/sqls/*_sqlite_wide.sql + do + echo "execute sql file: $file" + sqlite3 $WIDE_DB_FILE < "$file" + done + else echo "Execute SQL file ${SQL_FILE}" sqlite3 $DB_FILE < $SQL_FILE From aa33e7151b5dc24b35f02173d681c07c1dcbf48e Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Thu, 12 Dec 2024 17:56:20 +0800 Subject: [PATCH 06/11] feat(rdbsummary-wide-table): add wide table case --- .../case_3_order_wide_table_sqlite_wide.sql | 317 ++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql diff --git a/docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql b/docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql new file mode 100644 
index 000000000..094d7f1f0 --- /dev/null +++ b/docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql @@ -0,0 +1,317 @@ +CREATE TABLE order_wide_table ( + + -- order_base + order_id TEXT, -- 订单ID + order_no TEXT, -- 订单编号 + parent_order_no TEXT, -- 父订单编号 + order_type INTEGER, -- 订单类型:1实物2虚拟3混合 + order_status INTEGER, -- 订单状态 + order_source TEXT, -- 订单来源 + order_source_detail TEXT, -- 订单来源详情 + create_time DATETIME, -- 创建时间 + pay_time DATETIME, -- 支付时间 + finish_time DATETIME, -- 完成时间 + close_time DATETIME, -- 关闭时间 + cancel_time DATETIME, -- 取消时间 + cancel_reason TEXT, -- 取消原因 + order_remark TEXT, -- 订单备注 + seller_remark TEXT, -- 卖家备注 + buyer_remark TEXT, -- 买家备注 + is_deleted INTEGER, -- 是否删除 + delete_time DATETIME, -- 删除时间 + order_ip TEXT, -- 下单IP + order_platform TEXT, -- 下单平台 + order_device TEXT, -- 下单设备 + order_app_version TEXT, -- APP版本号 + + -- order_amount + currency TEXT, -- 货币类型 + exchange_rate REAL, -- 汇率 + original_amount REAL, -- 原始金额 + discount_amount REAL, -- 优惠金额 + coupon_amount REAL, -- 优惠券金额 + points_amount REAL, -- 积分抵扣金额 + shipping_amount REAL, -- 运费 + insurance_amount REAL, -- 保价费 + tax_amount REAL, -- 税费 + tariff_amount REAL, -- 关税 + payment_amount REAL, -- 实付金额 + commission_amount REAL, -- 佣金金额 + platform_fee REAL, -- 平台费用 + seller_income REAL, -- 卖家实收 + payment_currency TEXT, -- 支付货币 + payment_exchange_rate REAL, -- 支付汇率 + + -- user_info + user_id TEXT, -- 用户ID + user_name TEXT, -- 用户名 + user_nickname TEXT, -- 用户昵称 + user_level INTEGER, -- 用户等级 + user_type INTEGER, -- 用户类型 + register_time DATETIME, -- 注册时间 + register_source TEXT, -- 注册来源 + mobile TEXT, -- 手机号 + mobile_area TEXT, -- 手机号区号 + email TEXT, -- 邮箱 + is_vip INTEGER, -- 是否VIP + vip_level INTEGER, -- VIP等级 + vip_expire_time DATETIME, -- VIP过期时间 + user_age INTEGER, -- 用户年龄 + user_gender INTEGER, -- 用户性别 + user_birthday DATE, -- 用户生日 + user_avatar TEXT, -- 用户头像 + user_province TEXT, -- 用户所在省 + user_city TEXT, -- 用户所在市 + user_district TEXT, -- 用户所在区 + last_login_time DATETIME, -- 最后登录时间 + 
last_login_ip TEXT, -- 最后登录IP + user_credit_score INTEGER, -- 用户信用分 + total_order_count INTEGER, -- 历史订单数 + total_order_amount REAL, -- 历史订单金额 + + -- product_info + product_id TEXT, -- 商品ID + product_code TEXT, -- 商品编码 + product_name TEXT, -- 商品名称 + product_short_name TEXT, -- 商品短名称 + product_type INTEGER, -- 商品类型 + product_status INTEGER, -- 商品状态 + category_id TEXT, -- 类目ID + category_name TEXT, -- 类目名称 + category_path TEXT, -- 类目路径 + brand_id TEXT, -- 品牌ID + brand_name TEXT, -- 品牌名称 + brand_english_name TEXT, -- 品牌英文名 + seller_id TEXT, -- 卖家ID + seller_name TEXT, -- 卖家名称 + seller_type INTEGER, -- 卖家类型 + shop_id TEXT, -- 店铺ID + shop_name TEXT, -- 店铺名称 + product_price REAL, -- 商品价格 + market_price REAL, -- 市场价 + cost_price REAL, -- 成本价 + wholesale_price REAL, -- 批发价 + product_quantity INTEGER, -- 商品数量 + product_unit TEXT, -- 商品单位 + product_weight REAL, -- 商品重量(克) + product_volume REAL, -- 商品体积(cm³) + product_spec TEXT, -- 商品规格 + product_color TEXT, -- 商品颜色 + product_size TEXT, -- 商品尺寸 + product_material TEXT, -- 商品材质 + product_origin TEXT, -- 商品产地 + product_shelf_life INTEGER, -- 保质期(天) + manufacture_date DATE, -- 生产日期 + expiry_date DATE, -- 过期日期 + batch_number TEXT, -- 批次号 + product_barcode TEXT, -- 商品条码 + warehouse_id TEXT, -- 发货仓库ID + warehouse_name TEXT, -- 发货仓库名称 + + -- address_info + receiver_name TEXT, -- 收货人姓名 + receiver_mobile TEXT, -- 收货人手机 + receiver_tel TEXT, -- 收货人电话 + receiver_email TEXT, -- 收货人邮箱 + receiver_country TEXT, -- 国家 + receiver_province TEXT, -- 省份 + receiver_city TEXT, -- 城市 + receiver_district TEXT, -- 区县 + receiver_street TEXT, -- 街道 + receiver_address TEXT, -- 详细地址 + receiver_zip TEXT, -- 邮编 + address_type INTEGER, -- 地址类型 + is_default INTEGER, -- 是否默认地址 + longitude REAL, -- 经度 + latitude REAL, -- 纬度 + address_label TEXT, -- 地址标签 + + -- shipping_info + shipping_type INTEGER, -- 配送方式 + shipping_method TEXT, -- 配送方式名称 + shipping_company TEXT, -- 快递公司 + shipping_company_code TEXT, -- 快递公司编码 + shipping_no TEXT, -- 快递单号 + shipping_time 
DATETIME, -- 发货时间 + shipping_remark TEXT, -- 发货备注 + expect_receive_time DATETIME, -- 预计送达时间 + receive_time DATETIME, -- 收货时间 + sign_type INTEGER, -- 签收类型 + shipping_status INTEGER, -- 物流状态 + tracking_url TEXT, -- 物流跟踪URL + is_free_shipping INTEGER, -- 是否包邮 + shipping_insurance REAL, -- 运费险金额 + shipping_distance REAL, -- 配送距离 + delivered_time DATETIME, -- 送达时间 + delivery_staff_id TEXT, -- 配送员ID + delivery_staff_name TEXT, -- 配送员姓名 + delivery_staff_mobile TEXT, -- 配送员电话 + + -- payment_info + payment_id TEXT, -- 支付ID + payment_no TEXT, -- 支付单号 + payment_type INTEGER, -- 支付方式 + payment_method TEXT, -- 支付方式名称 + payment_status INTEGER, -- 支付状态 + payment_platform TEXT, -- 支付平台 + transaction_id TEXT, -- 交易流水号 + payment_time DATETIME, -- 支付时间 + payment_account TEXT, -- 支付账号 + payment_bank TEXT, -- 支付银行 + payment_card_type TEXT, -- 支付卡类型 + payment_card_no TEXT, -- 支付卡号 + payment_scene TEXT, -- 支付场景 + payment_client_ip TEXT, -- 支付IP + payment_device TEXT, -- 支付设备 + payment_remark TEXT, -- 支付备注 + payment_voucher TEXT, -- 支付凭证 + + -- promotion_info + promotion_id TEXT, -- 活动ID + promotion_name TEXT, -- 活动名称 + promotion_type INTEGER, -- 活动类型 + promotion_desc TEXT, -- 活动描述 + promotion_start_time DATETIME, -- 活动开始时间 + promotion_end_time DATETIME, -- 活动结束时间 + coupon_id TEXT, -- 优惠券ID + coupon_code TEXT, -- 优惠券码 + coupon_type INTEGER, -- 优惠券类型 + coupon_name TEXT, -- 优惠券名称 + coupon_desc TEXT, -- 优惠券描述 + points_used INTEGER, -- 使用积分 + points_gained INTEGER, -- 获得积分 + points_multiple REAL, -- 积分倍率 + is_first_order INTEGER, -- 是否首单 + is_new_customer INTEGER, -- 是否新客 + marketing_channel TEXT, -- 营销渠道 + marketing_source TEXT, -- 营销来源 + referral_code TEXT, -- 推荐码 + referral_user_id TEXT, -- 推荐人ID + + -- after_sale_info + refund_id TEXT, -- 退款ID + refund_no TEXT, -- 退款单号 + refund_type INTEGER, -- 退款类型 + refund_status INTEGER, -- 退款状态 + refund_reason TEXT, -- 退款原因 + refund_desc TEXT, -- 退款描述 + refund_time DATETIME, -- 退款时间 + refund_amount REAL, -- 退款金额 + return_shipping_no TEXT, -- 退货快递单号 + 
return_shipping_company TEXT, -- 退货快递公司 + return_shipping_time DATETIME, -- 退货时间 + refund_evidence TEXT, -- 退款凭证 + complaint_id TEXT, -- 投诉ID + complaint_type INTEGER, -- 投诉类型 + complaint_status INTEGER, -- 投诉状态 + complaint_content TEXT, -- 投诉内容 + complaint_time DATETIME, -- 投诉时间 + complaint_handle_time DATETIME, -- 投诉处理时间 + complaint_handle_result TEXT, -- 投诉处理结果 + evaluation_score INTEGER, -- 评价分数 + evaluation_content TEXT, -- 评价内容 + evaluation_time DATETIME, -- 评价时间 + evaluation_reply TEXT, -- 评价回复 + evaluation_reply_time DATETIME, -- 评价回复时间 + evaluation_images TEXT, -- 评价图片 + evaluation_videos TEXT, -- 评价视频 + is_anonymous INTEGER, -- 是否匿名评价 + + -- invoice_info + invoice_type INTEGER, -- 发票类型 + invoice_title TEXT, -- 发票抬头 + invoice_content TEXT, -- 发票内容 + tax_no TEXT, -- 税号 + invoice_amount REAL, -- 发票金额 + invoice_status INTEGER, -- 发票状态 + invoice_time DATETIME, -- 开票时间 + invoice_number TEXT, -- 发票号码 + invoice_code TEXT, -- 发票代码 + company_name TEXT, -- 单位名称 + company_address TEXT, -- 单位地址 + company_tel TEXT, -- 单位电话 + company_bank TEXT, -- 开户银行 + company_account TEXT, -- 银行账号 + + -- delivery_time_info + expect_delivery_time DATETIME, -- 期望配送时间 + delivery_period_type INTEGER, -- 配送时段类型 + delivery_period_start TEXT, -- 配送时段开始 + delivery_period_end TEXT, -- 配送时段结束 + delivery_priority INTEGER, -- 配送优先级 + + -- tag_info + order_tags TEXT, -- 订单标签 + user_tags TEXT, -- 用户标签 + product_tags TEXT, -- 商品标签 + risk_level INTEGER, -- 风险等级 + risk_tags TEXT, -- 风险标签 + business_tags TEXT, -- 业务标签 + + -- commercial_info + gross_profit REAL, -- 毛利 + gross_profit_rate REAL, -- 毛利率 + settlement_amount REAL, -- 结算金额 + settlement_time DATETIME, -- 结算时间 + settlement_cycle INTEGER, -- 结算周期 + settlement_status INTEGER, -- 结算状态 + commission_rate REAL, -- 佣金比例 + platform_service_fee REAL, -- 平台服务费 + ad_cost REAL, -- 广告费用 + promotion_cost REAL -- 推广费用 +); + +-- 插入示例数据 +INSERT INTO order_wide_table ( + -- 基础订单信息 + order_id, order_no, order_type, order_status, create_time, order_source, + -- 
订单金额 + original_amount, payment_amount, shipping_amount, + -- 用户信息 + user_id, user_name, user_level, mobile, + -- 商品信息 + product_id, product_name, product_quantity, product_price, + -- 收货信息 + receiver_name, receiver_mobile, receiver_address, + -- 物流信息 + shipping_no, shipping_status, + -- 支付信息 + payment_type, payment_status, + -- 营销信息 + promotion_id, coupon_amount, + -- 发票信息 + invoice_type, invoice_title +) VALUES +( + 'ORD20240101001', 'NO20240101001', 1, 2, '2024-01-01 10:00:00', 'APP', + 199.99, 188.88, 10.00, + 'USER001', '张三', 2, '13800138000', + 'PRD001', 'iPhone 15 手机壳', 2, 89.99, + '李四', '13900139000', '北京市朝阳区XX路XX号', + 'SF123456789', 1, + 1, 1, + 'PROM001', 20.00, + 1, '个人' +), +( + 'ORD20240101002', 'NO20240101002', 1, 1, '2024-01-01 11:00:00', 'H5', + 299.99, 279.99, 0.00, + 'USER002', '王五', 3, '13700137000', + 'PRD002', 'AirPods Pro 保护套', 1, 299.99, + '赵六', '13600136000', '上海市浦东新区XX路XX号', + 'YT987654321', 2, + 2, 2, + 'PROM002', 10.00, + 2, '上海科技有限公司' +), +( + 'ORD20240101003', 'NO20240101003', 2, 3, '2024-01-01 12:00:00', 'WEB', + 1999.99, 1899.99, 0.00, + 'USER003', '陈七', 4, '13500135000', + 'PRD003', 'MacBook Pro 电脑包', 1, 1999.99, + '孙八', '13400134000', '广州市天河区XX路XX号', + 'JD123123123', 3, + 3, 1, + 'PROM003', 100.00, + 1, '个人' +); From 2149e2a6dc50fa2ac8edd9ec1cc53ff139d69629 Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Thu, 12 Dec 2024 19:25:04 +0800 Subject: [PATCH 07/11] chore --- dbgpt/_private/config.py | 4 +++- dbgpt/app/scene/chat_dashboard/chat.py | 4 +++- dbgpt/app/scene/chat_db/auto_execute/chat.py | 4 +++- dbgpt/rag/assembler/db_schema.py | 1 - dbgpt/rag/assembler/tests/test_embedding_assembler.py | 2 +- dbgpt/rag/summary/db_summary_client.py | 2 +- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index e6e124871..55427a12f 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -264,7 +264,9 @@ def __init__(self) -> None: # EMBEDDING Configuration 
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") - self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512)) + self.EMBEDDING_MODEL_MAX_SEQ_LEN = int( + os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512) + ) # Rerank model configuration self.RERANK_MODEL = os.getenv("RERANK_MODEL") self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH") diff --git a/dbgpt/app/scene/chat_dashboard/chat.py b/dbgpt/app/scene/chat_dashboard/chat.py index 25962aadd..d905e7a61 100644 --- a/dbgpt/app/scene/chat_dashboard/chat.py +++ b/dbgpt/app/scene/chat_dashboard/chat.py @@ -61,7 +61,9 @@ async def generate_input_values(self) -> Dict: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - table_infos = await client.aget_db_summary(self.db_name, self.current_user_input, self.top_k) + table_infos = await client.aget_db_summary( + self.db_name, self.current_user_input, self.top_k + ) print("dashboard vector find tables:{}", table_infos) except Exception as e: print("db summary find error!" + str(e)) diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py index 528180bd8..144f4d862 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/chat.py +++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py @@ -55,7 +55,9 @@ async def generate_input_values(self) -> Dict: table_infos = None try: with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): - table_infos = await client.aget_db_summary(self.db_name, self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE) + table_infos = await client.aget_db_summary( + self.db_name, self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE + ) except Exception as e: print("db summary find error!" 
+ str(e)) if not table_infos: diff --git a/dbgpt/rag/assembler/db_schema.py b/dbgpt/rag/assembler/db_schema.py index cd71f470d..03763130c 100644 --- a/dbgpt/rag/assembler/db_schema.py +++ b/dbgpt/rag/assembler/db_schema.py @@ -1,5 +1,4 @@ """DBSchemaAssembler.""" -import os from typing import Any, List, Optional from dbgpt.core import Chunk, Embeddings diff --git a/dbgpt/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/rag/assembler/tests/test_embedding_assembler.py index a2e3c9679..3d1cba3ad 100644 --- a/dbgpt/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/rag/assembler/tests/test_embedding_assembler.py @@ -83,6 +83,6 @@ def test_load_knowledge( embeddings=mock_embedding_factory.create(), table_vector_store_connector=mock_table_vector_store_connector, field_vector_store_connector=mock_field_vector_store_connector, - max_seq_length=10 + max_seq_length=10, ) assert len(assembler._chunks) > 1 diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 23cd23bb9..8f236c71f 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -157,7 +157,7 @@ def init_db_profile(self, db_summary_client, dbname): table_vector_store_connector=table_vector_connector, field_vector_store_connector=field_vector_connector, chunk_parameters=chunk_parameters, - max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN + max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN, ) if len(db_assembler.get_chunks()) > 0: From fa4a988e15ee30a54a9114e57d3b2150c602f17b Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Sun, 15 Dec 2024 18:19:00 +0800 Subject: [PATCH 08/11] chore(rdb_summary-wide_table)) --- dbgpt/app/scene/chat_dashboard/chat.py | 8 +- dbgpt/app/scene/chat_db/auto_execute/chat.py | 8 +- .../app/scene/chat_db/professional_qa/chat.py | 8 +- dbgpt/rag/assembler/db_schema.py | 18 ++++- dbgpt/rag/operators/db_schema.py | 17 +++- dbgpt/rag/retriever/db_schema.py | 80 ++++++++----------- 
dbgpt/rag/retriever/tests/test_db_struct.py | 4 +- dbgpt/rag/summary/db_summary_client.py | 45 +---------- dbgpt/rag/summary/rdbms_db_summary.py | 11 +-- dbgpt/util/chat_util.py | 42 +++++++--- examples/rag/db_schema_rag_example.py | 4 +- 11 files changed, 123 insertions(+), 122 deletions(-) diff --git a/dbgpt/app/scene/chat_dashboard/chat.py b/dbgpt/app/scene/chat_dashboard/chat.py index d905e7a61..bb35de7cc 100644 --- a/dbgpt/app/scene/chat_dashboard/chat.py +++ b/dbgpt/app/scene/chat_dashboard/chat.py @@ -61,8 +61,12 @@ async def generate_input_values(self) -> Dict: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - table_infos = await client.aget_db_summary( - self.db_name, self.current_user_input, self.top_k + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + self.top_k, ) print("dashboard vector find tables:{}", table_infos) except Exception as e: diff --git a/dbgpt/app/scene/chat_db/auto_execute/chat.py b/dbgpt/app/scene/chat_db/auto_execute/chat.py index 144f4d862..69b3d493b 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/chat.py +++ b/dbgpt/app/scene/chat_db/auto_execute/chat.py @@ -55,8 +55,12 @@ async def generate_input_values(self) -> Dict: table_infos = None try: with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): - table_infos = await client.aget_db_summary( - self.db_name, self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + CFG.KNOWLEDGE_SEARCH_TOP_SIZE, ) except Exception as e: print("db summary find error!" 
+ str(e)) diff --git a/dbgpt/app/scene/chat_db/professional_qa/chat.py b/dbgpt/app/scene/chat_db/professional_qa/chat.py index 0349e8f13..6049bf76c 100644 --- a/dbgpt/app/scene/chat_db/professional_qa/chat.py +++ b/dbgpt/app/scene/chat_db/professional_qa/chat.py @@ -55,8 +55,12 @@ async def generate_input_values(self) -> Dict: if self.db_name: client = DBSummaryClient(system_app=CFG.SYSTEM_APP) try: - table_infos = await client.aget_db_summary( - dbname=self.db_name, query=self.current_user_input, topk=self.top_k + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + self.top_k, ) except Exception as e: print("db summary find error!" + str(e)) diff --git a/dbgpt/rag/assembler/db_schema.py b/dbgpt/rag/assembler/db_schema.py index 03763130c..ba029aaa4 100644 --- a/dbgpt/rag/assembler/db_schema.py +++ b/dbgpt/rag/assembler/db_schema.py @@ -1,4 +1,5 @@ """DBSchemaAssembler.""" +import os from typing import Any, List, Optional from dbgpt.core import Chunk, Embeddings @@ -10,6 +11,7 @@ from ..embedding.embedding_factory import DefaultEmbeddingFactory from ..knowledge.datasource import DatasourceKnowledge from ..retriever.db_schema import DBSchemaRetriever +from ...storage.vector_store.base import VectorStoreConfig class DBSchemaAssembler(BaseAssembler): @@ -37,7 +39,7 @@ def __init__( self, connector: BaseConnector, table_vector_store_connector: VectorStoreConnector, - field_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embeddings: Optional[Embeddings] = None, @@ -58,7 +60,17 @@ def __init__( """ self._connector = connector self._table_vector_store_connector = table_vector_store_connector - self._field_vector_store_connector = field_vector_store_connector + + field_vector_store_config = VectorStoreConfig( + 
name=table_vector_store_connector.vector_store_config.name + "_field" + ) + self._field_vector_store_connector = field_vector_store_connector or VectorStoreConnector.from_default( + os.getenv( + "VECTOR_STORE_TYPE", "Chroma" + ), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) self._embedding_model = embedding_model if self._embedding_model and not embeddings: @@ -94,7 +106,7 @@ def load_from_connection( cls, connector: BaseConnector, table_vector_store_connector: VectorStoreConnector, - field_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, chunk_parameters: Optional[ChunkParameters] = None, embedding_model: Optional[str] = None, embeddings: Optional[Embeddings] = None, diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index 15c6986d6..53ecf1ab0 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -1,5 +1,5 @@ """The DBSchema Retriever Operator.""" - +import os from typing import List, Optional from dbgpt.core import Chunk @@ -11,6 +11,7 @@ from ..chunk_manager import ChunkParameters from ..retriever.db_schema import DBSchemaRetriever from .assembler import AssemblerOperator +from ...storage.vector_store.base import VectorStoreConfig class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]): @@ -56,7 +57,7 @@ def __init__( self, connector: BaseConnector, table_vector_store_connector: VectorStoreConnector, - field_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, chunk_parameters: Optional[ChunkParameters] = None, **kwargs ): @@ -72,7 +73,17 @@ def __init__( chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") self._chunk_parameters = chunk_parameters self._table_vector_store_connector = table_vector_store_connector - self._field_vector_store_connector = field_vector_store_connector + + 
field_vector_store_config = VectorStoreConfig( + name=table_vector_store_connector.vector_store_config.name + "_field" + ) + self._field_vector_store_connector = field_vector_store_connector or VectorStoreConnector.from_default( + os.getenv( + "VECTOR_STORE_TYPE", "Chroma" + ), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) self._connector = connector super().__init__(**kwargs) diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index 2b2700c75..f10599a21 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,6 +1,8 @@ """DBSchema retriever.""" +import logging from functools import reduce from typing import List, Optional, cast +from dbgpt._private.config import Config from dbgpt.core import Chunk from dbgpt.datasource.base import BaseConnector @@ -9,16 +11,19 @@ from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters -from dbgpt.util.chat_util import run_async_tasks +from dbgpt.util.chat_util import run_async_tasks, run_tasks +from dbgpt.util.executor_utils import blocking_func_to_async_no_executor +logger = logging.getLogger(__name__) +CFG = Config() class DBSchemaRetriever(BaseRetriever): """DBSchema retriever.""" def __init__( self, table_vector_store_connector: VectorStoreConnector, - field_vector_store_connector: VectorStoreConnector, + field_vector_store_connector: VectorStoreConnector = None, separator: str = "--table-field-separator--", top_k: int = 4, connector: Optional[BaseConnector] = None, @@ -115,17 +120,11 @@ def _retrieve( List[Chunk]: list of chunks """ if self._need_embeddings: - queries = [query] - candidates = [ - self._table_vector_store_connector.similar_search( - query, self._top_k, filters - ) - for query in queries - ] - return cast(List[Chunk], reduce(lambda x, y: x + 
y, candidates)) + return self._similarity_search(query, filters) else: - if not self._connector: - raise RuntimeError("RDBMSConnector connection is required.") + from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 + _parse_db_summary, + ) table_summaries = _parse_db_summary(self._connector) return [Chunk(content=table_summary) for table_summary in table_summaries] @@ -159,21 +158,11 @@ async def _aretrieve( Returns: List[Chunk]: list of chunks """ - if self._need_embeddings: - candidates = [self._similarity_search(query, filters)] - result_candidates = await run_async_tasks( - tasks=candidates, concurrency_limit=3 - ) - return result_candidates[0] - else: - from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 - _parse_db_summary, - ) - - table_summaries = await run_async_tasks( - tasks=[self._aparse_db_summary()], concurrency_limit=1 - ) - return [Chunk(content=table_summary) for table_summary in table_summaries] + return await blocking_func_to_async_no_executor( + func=self._retrieve, + query=query, + filters=filters, + ) async def _aretrieve_with_score( self, @@ -190,43 +179,40 @@ async def _aretrieve_with_score( """ return await self._aretrieve(query, filters) - async def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk: + def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk: metadata = table_chunk.metadata metadata["part"] = "field" filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()] - field_chunks = ( - await self._field_vector_store_connector.asimilar_search_with_scores( - query, self._top_k, 0, MetadataFilters(filters=filters) - ) - ) + field_chunks = self._field_vector_store_connector.similar_search_with_scores( + query, self._top_k, 0, MetadataFilters(filters=filters)) field_contents = [chunk.content for chunk in field_chunks] table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents) return table_chunk - async def _similarity_search( + def _similarity_search( self, query, 
filters: Optional[MetadataFilters] = None ) -> List[Chunk]: """Similar search.""" - table_chunks = ( - await self._table_vector_store_connector.asimilar_search_with_scores( + table_chunks = self._table_vector_store_connector.similar_search_with_scores( query, self._top_k, 0, filters ) - ) + not_sep_chunks = [ chunk for chunk in table_chunks if not chunk.metadata.get("separated") ] separated_chunks = [ chunk for chunk in table_chunks if chunk.metadata.get("separated") ] - separated_result = await run_async_tasks( - tasks=[self._retrieve_field(chunk, query) for chunk in separated_chunks] - ) - return not_sep_chunks + separated_result + if not separated_chunks: + return not_sep_chunks - async def _aparse_db_summary(self) -> List[str]: - """Similar search.""" - from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary + # Create tasks list + tasks = [ + lambda c=chunk: self._retrieve_field(c, query) + for chunk in separated_chunks + ] + # Run tasks concurrently + separated_result = run_tasks(tasks, concurrency_limit=3) - if not self._connector: - raise RuntimeError("RDBMSConnector connection is required.") - return _parse_db_summary(self._connector) + # Combine and return results + return not_sep_chunks + separated_result diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index f4d9cecad..a54e3f416 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -17,14 +17,14 @@ def mock_db_connection(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4 + mock_connector.similar_search_with_scores.return_value = [Chunk(content="Table summary")] * 4 return mock_connector @pytest.fixture def mock_field_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search.return_value = [Chunk(content="Field summary")] * 4 + 
mock_connector.similar_search_with_scores.return_value = [Chunk(content="Field summary")] * 4 return mock_connector diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index 8f236c71f..b39eb3109 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -55,23 +55,17 @@ def get_db_summary(self, dbname, query, topk): vector_store_name = dbname + "_profile" table_vector_store_config = VectorStoreConfig(name=vector_store_name) - field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, self.embeddings, vector_store_config=table_vector_store_config, ) - field_vector_connector = VectorStoreConnector.from_default( - CFG.VECTOR_STORE_TYPE, - self.embeddings, - vector_store_config=field_vector_store_config, - ) + from dbgpt.rag.retriever.db_schema import DBSchemaRetriever retriever = DBSchemaRetriever( top_k=topk, table_vector_store_connector=table_vector_connector, - field_vector_store_connector=field_vector_connector, separator="--table-field-separator--", ) @@ -79,36 +73,6 @@ def get_db_summary(self, dbname, query, topk): ans = [d.content for d in table_docs] return ans - async def aget_db_summary(self, dbname, query, topk): - """Get user query related tables info.""" - from dbgpt.serve.rag.connector import VectorStoreConnector - from dbgpt.storage.vector_store.base import VectorStoreConfig - - vector_store_name = dbname + "_profile" - table_vector_store_config = VectorStoreConfig(name=vector_store_name) - field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") - table_vector_connector = VectorStoreConnector.from_default( - CFG.VECTOR_STORE_TYPE, - self.embeddings, - vector_store_config=table_vector_store_config, - ) - field_vector_connector = VectorStoreConnector.from_default( - CFG.VECTOR_STORE_TYPE, - self.embeddings, - 
vector_store_config=field_vector_store_config, - ) - from dbgpt.rag.retriever.db_schema import DBSchemaRetriever - - retriever = DBSchemaRetriever( - top_k=topk, - table_vector_store_connector=table_vector_connector, - field_vector_store_connector=field_vector_connector, - ) - - table_docs = await retriever.aretrieve(query) - ans = [d.content for d in table_docs] - return ans - def init_db_summary(self): """Initialize db summary profile.""" db_mange = CFG.local_db_manager @@ -135,17 +99,11 @@ def init_db_profile(self, db_summary_client, dbname): from dbgpt.storage.vector_store.base import VectorStoreConfig table_vector_store_config = VectorStoreConfig(name=vector_store_name) - field_vector_store_config = VectorStoreConfig(name=vector_store_name + "_field") table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, self.embeddings, vector_store_config=table_vector_store_config, ) - field_vector_connector = VectorStoreConnector.from_default( - CFG.VECTOR_STORE_TYPE, - self.embeddings, - vector_store_config=field_vector_store_config, - ) if not table_vector_connector.vector_name_exists(): from dbgpt.rag.assembler.db_schema import DBSchemaAssembler @@ -155,7 +113,6 @@ def init_db_profile(self, db_summary_client, dbname): db_assembler = DBSchemaAssembler.load_from_connection( connector=db_summary_client.db, table_vector_store_connector=table_vector_connector, - field_vector_store_connector=field_vector_connector, chunk_parameters=chunk_parameters, max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN, ) diff --git a/dbgpt/rag/summary/rdbms_db_summary.py b/dbgpt/rag/summary/rdbms_db_summary.py index c0bd35972..c2786da51 100644 --- a/dbgpt/rag/summary/rdbms_db_summary.py +++ b/dbgpt/rag/summary/rdbms_db_summary.py @@ -119,20 +119,21 @@ def _split_columns_str(columns: List[str], model_dimension: int): for element_str in columns: element_length = len(element_str) - # 如果加上当前元素的长度会超过阈值,则将当前字符串添加到结果中,并重置 + # If adding the current element's length would exceed 
the threshold, + # add the current string to results and reset if current_length + element_length > model_dimension: - result.append(current_string.strip()) # 去掉末尾的空格 + result.append(current_string.strip()) # Remove trailing spaces current_string = element_str current_length = element_length else: - # 如果当前字符串为空,直接添加元素 + # If current string is empty, add element directly if current_string: current_string += "," + element_str else: current_string = element_str - current_length += element_length + 1 # 加上空格的长度 + current_length += element_length + 1 # Add length of space - # 最后一段字符串 + # Handle the last string segment if current_string: result.append(current_string.strip()) diff --git a/dbgpt/util/chat_util.py b/dbgpt/util/chat_util.py index 490f21a5f..2fd7723d5 100644 --- a/dbgpt/util/chat_util.py +++ b/dbgpt/util/chat_util.py @@ -1,5 +1,6 @@ import asyncio -from typing import Any, Coroutine, List +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Coroutine, List, Callable async def llm_chat_response_nostream(chat_scene: str, **chat_param): @@ -47,13 +48,34 @@ async def _execute_task(task): def run_tasks( - tasks: List[Coroutine], + tasks: List[Callable], + concurrency_limit: int = None, ) -> List[Any]: - """Run a list of async tasks.""" - tasks_to_execute: List[Any] = tasks - - async def _gather() -> List[Any]: - return await asyncio.gather(*tasks_to_execute) - - outputs: List[Any] = asyncio.run(_gather()) - return outputs + """ + Run a list of tasks concurrently using a thread pool. 
+ + Args: + tasks: List of callable functions to execute + concurrency_limit: Maximum number of concurrent threads (optional) + + Returns: + List of results from all tasks in the order they were submitted + """ + max_workers = concurrency_limit if concurrency_limit else None + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks and get futures + futures = [executor.submit(task) for task in tasks] + + # Collect results in order, raising any exceptions + results = [] + for future in futures: + try: + results.append(future.result()) + except Exception as e: + # Cancel any pending futures + for f in futures: + f.cancel() + raise e + + return results \ No newline at end of file diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index 7fe1ec15e..253b6c0a2 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -4,8 +4,8 @@ from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.assembler import DBSchemaAssembler from dbgpt.rag.embedding import DefaultEmbeddingFactory +from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig -from dbgpt.storage.vector_store.connector import VectorStoreConnector """DB struct rag example. 
pre-requirements: @@ -63,7 +63,7 @@ def _create_vector_connector(): vector_connector = _create_vector_connector() assembler = DBSchemaAssembler.load_from_connection( connector=connection, - vector_store_connector=vector_connector, + table_vector_store_connector=vector_connector ) assembler.persist() # get db schema retriever From 750e3e86295db31e2d25312d6f410cb45fffcfdd Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Sun, 15 Dec 2024 18:42:08 +0800 Subject: [PATCH 09/11] fix(rdb_summary-wide_table): fix tests --- dbgpt/rag/assembler/db_schema.py | 15 ++++++------ .../tests/test_db_struct_assembler.py | 13 +++++----- .../tests/test_embedding_assembler.py | 12 +++------- dbgpt/rag/operators/db_schema.py | 15 ++++++------ dbgpt/rag/retriever/db_schema.py | 24 +++++++++---------- dbgpt/rag/retriever/tests/test_db_struct.py | 8 +++++-- dbgpt/util/chat_util.py | 8 +++---- examples/rag/db_schema_rag_example.py | 3 +-- 8 files changed, 47 insertions(+), 51 deletions(-) diff --git a/dbgpt/rag/assembler/db_schema.py b/dbgpt/rag/assembler/db_schema.py index ba029aaa4..99507a2b2 100644 --- a/dbgpt/rag/assembler/db_schema.py +++ b/dbgpt/rag/assembler/db_schema.py @@ -6,12 +6,12 @@ from dbgpt.datasource.base import BaseConnector from ...serve.rag.connector import VectorStoreConnector +from ...storage.vector_store.base import VectorStoreConfig from ..assembler.base import BaseAssembler from ..chunk_manager import ChunkParameters from ..embedding.embedding_factory import DefaultEmbeddingFactory from ..knowledge.datasource import DatasourceKnowledge from ..retriever.db_schema import DBSchemaRetriever -from ...storage.vector_store.base import VectorStoreConfig class DBSchemaAssembler(BaseAssembler): @@ -64,12 +64,13 @@ def __init__( field_vector_store_config = VectorStoreConfig( name=table_vector_store_connector.vector_store_config.name + "_field" ) - self._field_vector_store_connector = field_vector_store_connector or VectorStoreConnector.from_default( - os.getenv( - 
"VECTOR_STORE_TYPE", "Chroma" - ), - self._table_vector_store_connector.current_embeddings, - vector_store_config=field_vector_store_config, + self._field_vector_store_connector = ( + field_vector_store_connector + or VectorStoreConnector.from_default( + os.getenv("VECTOR_STORE_TYPE", "Chroma"), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) ) self._embedding_model = embedding_model diff --git a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py index 6f8f59cf1..01a7a9e2d 100644 --- a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py +++ b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py @@ -1,5 +1,5 @@ from typing import List -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -26,7 +26,7 @@ def mock_table_vector_store_connector(): "table_name": "user", }, ) - mock_connector.asimilar_search_with_scores = AsyncMock(return_value=[chunk]) + mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk]) return mock_connector @@ -63,7 +63,7 @@ def mock_field_vector_store_connector(): "table_name": "user", }, ) - mock_connector.asimilar_search_with_scores = AsyncMock( + mock_connector.similar_search_with_scores = MagicMock( return_value=[chunk1, chunk2, chunk3] ) return mock_connector @@ -96,10 +96,9 @@ def mock_parse_db_summary() -> str: @patch.object( dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary ) -@pytest.mark.asyncio -async def test_retrieve_with_mocked_summary(dbstruct_retriever): +def test_retrieve_with_mocked_summary(dbstruct_retriever): query = "Table summary" - chunks: List[Chunk] = await dbstruct_retriever._aretrieve(query) + chunks: List[Chunk] = dbstruct_retriever._retrieve(query) assert isinstance(chunks[0], Chunk) assert chunks[0].content == ( "table_name: user\ncomment: user about dbgpt\n" @@ -108,7 +107,7 @@ async 
def test_retrieve_with_mocked_summary(dbstruct_retriever): ) -async def async_mock_parse_db_summary() -> str: +def async_mock_parse_db_summary() -> str: """Asynchronous patch for _parse_db_summary method.""" return ( "table_name: user\ncomment: user about dbgpt\n" diff --git a/dbgpt/rag/assembler/tests/test_embedding_assembler.py b/dbgpt/rag/assembler/tests/test_embedding_assembler.py index 3d1cba3ad..f2ac1577e 100644 --- a/dbgpt/rag/assembler/tests/test_embedding_assembler.py +++ b/dbgpt/rag/assembler/tests/test_embedding_assembler.py @@ -5,7 +5,7 @@ from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory +from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter from dbgpt.serve.rag.connector import VectorStoreConnector @@ -56,12 +56,8 @@ def mock_embedding_factory(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock(spec=VectorStoreConnector) - return mock_connector - - -@pytest.fixture -def mock_field_vector_store_connector(): - mock_connector = MagicMock(spec=VectorStoreConnector) + mock_connector.vector_store_config.name = "table_vector_store_name" + mock_connector.current_embeddings = DefaultEmbeddings() return mock_connector @@ -70,7 +66,6 @@ def test_load_knowledge( mock_chunk_parameters, mock_embedding_factory, mock_table_vector_store_connector, - mock_field_vector_store_connector, ): mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" mock_chunk_parameters.text_splitter = RDBTextSplitter( @@ -82,7 +77,6 @@ def test_load_knowledge( chunk_parameters=mock_chunk_parameters, embeddings=mock_embedding_factory.create(), table_vector_store_connector=mock_table_vector_store_connector, - 
field_vector_store_connector=mock_field_vector_store_connector, max_seq_length=10, ) assert len(assembler._chunks) > 1 diff --git a/dbgpt/rag/operators/db_schema.py b/dbgpt/rag/operators/db_schema.py index 53ecf1ab0..4e642c549 100644 --- a/dbgpt/rag/operators/db_schema.py +++ b/dbgpt/rag/operators/db_schema.py @@ -7,11 +7,11 @@ from dbgpt.datasource.base import BaseConnector from dbgpt.serve.rag.connector import VectorStoreConnector +from ...storage.vector_store.base import VectorStoreConfig from ..assembler.db_schema import DBSchemaAssembler from ..chunk_manager import ChunkParameters from ..retriever.db_schema import DBSchemaRetriever from .assembler import AssemblerOperator -from ...storage.vector_store.base import VectorStoreConfig class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]): @@ -77,12 +77,13 @@ def __init__( field_vector_store_config = VectorStoreConfig( name=table_vector_store_connector.vector_store_config.name + "_field" ) - self._field_vector_store_connector = field_vector_store_connector or VectorStoreConnector.from_default( - os.getenv( - "VECTOR_STORE_TYPE", "Chroma" - ), - self._table_vector_store_connector.current_embeddings, - vector_store_config=field_vector_store_config, + self._field_vector_store_connector = ( + field_vector_store_connector + or VectorStoreConnector.from_default( + os.getenv("VECTOR_STORE_TYPE", "Chroma"), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) ) self._connector = connector super().__init__(**kwargs) diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index f10599a21..a10f32adc 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,22 +1,23 @@ """DBSchema retriever.""" import logging -from functools import reduce -from typing import List, Optional, cast -from dbgpt._private.config import Config +from typing import List, Optional +from dbgpt._private.config import Config 
from dbgpt.core import Chunk from dbgpt.datasource.base import BaseConnector from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker -from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary +from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters -from dbgpt.util.chat_util import run_async_tasks, run_tasks +from dbgpt.util.chat_util import run_tasks from dbgpt.util.executor_utils import blocking_func_to_async_no_executor logger = logging.getLogger(__name__) CFG = Config() + + class DBSchemaRetriever(BaseRetriever): """DBSchema retriever.""" @@ -122,9 +123,6 @@ def _retrieve( if self._need_embeddings: return self._similarity_search(query, filters) else: - from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 - _parse_db_summary, - ) table_summaries = _parse_db_summary(self._connector) return [Chunk(content=table_summary) for table_summary in table_summaries] @@ -184,7 +182,8 @@ def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk: metadata["part"] = "field" filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()] field_chunks = self._field_vector_store_connector.similar_search_with_scores( - query, self._top_k, 0, MetadataFilters(filters=filters)) + query, self._top_k, 0, MetadataFilters(filters=filters) + ) field_contents = [chunk.content for chunk in field_chunks] table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents) return table_chunk @@ -194,8 +193,8 @@ def _similarity_search( ) -> List[Chunk]: """Similar search.""" table_chunks = self._table_vector_store_connector.similar_search_with_scores( - query, self._top_k, 0, filters - ) + query, self._top_k, 0, filters + ) not_sep_chunks = [ chunk for chunk in table_chunks if not chunk.metadata.get("separated") @@ -208,8 +207,7 @@ def _similarity_search( 
# Create tasks list tasks = [ - lambda c=chunk: self._retrieve_field(c, query) - for chunk in separated_chunks + lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks ] # Run tasks concurrently separated_result = run_tasks(tasks, concurrency_limit=3) diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index a54e3f416..eea774617 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -17,14 +17,18 @@ def mock_db_connection(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search_with_scores.return_value = [Chunk(content="Table summary")] * 4 + mock_connector.similar_search_with_scores.return_value = [ + Chunk(content="Table summary") + ] * 4 return mock_connector @pytest.fixture def mock_field_vector_store_connector(): mock_connector = MagicMock() - mock_connector.similar_search_with_scores.return_value = [Chunk(content="Field summary")] * 4 + mock_connector.similar_search_with_scores.return_value = [ + Chunk(content="Field summary") + ] * 4 return mock_connector diff --git a/dbgpt/util/chat_util.py b/dbgpt/util/chat_util.py index 2fd7723d5..ffb170093 100644 --- a/dbgpt/util/chat_util.py +++ b/dbgpt/util/chat_util.py @@ -1,6 +1,6 @@ import asyncio from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Coroutine, List, Callable +from typing import Any, Callable, Coroutine, List async def llm_chat_response_nostream(chat_scene: str, **chat_param): @@ -48,8 +48,8 @@ async def _execute_task(task): def run_tasks( - tasks: List[Callable], - concurrency_limit: int = None, + tasks: List[Callable], + concurrency_limit: int = None, ) -> List[Any]: """ Run a list of tasks concurrently using a thread pool. 
@@ -78,4 +78,4 @@ def run_tasks( f.cancel() raise e - return results \ No newline at end of file + return results diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index 253b6c0a2..7cfbf62d8 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -62,8 +62,7 @@ def _create_vector_connector(): connection = _create_temporary_connection() vector_connector = _create_vector_connector() assembler = DBSchemaAssembler.load_from_connection( - connector=connection, - table_vector_store_connector=vector_connector + connector=connection, table_vector_store_connector=vector_connector ) assembler.persist() # get db schema retriever From 767539d0103cc24708b9132c56366e01628e4369 Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Wed, 18 Dec 2024 11:16:38 +0800 Subject: [PATCH 10/11] fix(rdb_summary-wide_table): self._field_vector_store_connector None error --- .../assembler/tests/test_db_struct_assembler.py | 1 + dbgpt/rag/retriever/db_schema.py | 14 +++++++++++++- dbgpt/rag/retriever/tests/test_db_struct.py | 1 + 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py index 01a7a9e2d..598160374 100644 --- a/dbgpt/rag/assembler/tests/test_db_struct_assembler.py +++ b/dbgpt/rag/assembler/tests/test_db_struct_assembler.py @@ -17,6 +17,7 @@ def mock_db_connection(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock() + mock_connector.vector_store_config.name = "table_name" chunk = Chunk( content="table_name: user\ncomment: user about dbgpt", metadata={ diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index a10f32adc..1326e2385 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -1,5 +1,6 @@ """DBSchema retriever.""" import logging +import os from typing import List, Optional from 
dbgpt._private.config import Config @@ -9,6 +10,7 @@ from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary from dbgpt.serve.rag.connector import VectorStoreConnector +from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters from dbgpt.util.chat_util import run_tasks from dbgpt.util.executor_utils import blocking_func_to_async_no_executor @@ -102,7 +104,17 @@ def _create_temporary_connection(): self._connector = connector self._query_rewrite = query_rewrite self._table_vector_store_connector = table_vector_store_connector - self._field_vector_store_connector = field_vector_store_connector + field_vector_store_config = VectorStoreConfig( + name=table_vector_store_connector.vector_store_config.name + "_field" + ) + self._field_vector_store_connector = ( + field_vector_store_connector + or VectorStoreConnector.from_default( + os.getenv("VECTOR_STORE_TYPE", "Chroma"), + self._table_vector_store_connector.current_embeddings, + vector_store_config=field_vector_store_config, + ) + ) self._need_embeddings = False if self._table_vector_store_connector: self._need_embeddings = True diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index eea774617..f34a4070a 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -17,6 +17,7 @@ def mock_db_connection(): @pytest.fixture def mock_table_vector_store_connector(): mock_connector = MagicMock() + mock_connector.vector_store_config.name = "table_name" mock_connector.similar_search_with_scores.return_value = [ Chunk(content="Table summary") ] * 4 From c6318c86337d52ff9d18ff85b96e96d2042ffeed Mon Sep 17 00:00:00 2001 From: dongzhancai1 Date: Wed, 18 Dec 2024 11:57:41 +0800 Subject: [PATCH 11/11] fix(rdb_summary-wide_table): delete database profile --- 
dbgpt/rag/summary/db_summary_client.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index b39eb3109..073c072bc 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -126,16 +126,26 @@ def init_db_profile(self, db_summary_client, dbname): def delete_db_profile(self, dbname): """Delete db profile.""" vector_store_name = dbname + "_profile" + table_vector_store_name = dbname + "_profile" + field_vector_store_name = dbname + "_profile_field" from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig - vector_store_config = VectorStoreConfig(name=vector_store_name) - vector_connector = VectorStoreConnector.from_default( + table_vector_store_config = VectorStoreConfig(name=vector_store_name) + field_vector_store_config = VectorStoreConfig(name=field_vector_store_name) + table_vector_connector = VectorStoreConnector.from_default( CFG.VECTOR_STORE_TYPE, self.embeddings, - vector_store_config=vector_store_config, + vector_store_config=table_vector_store_config, ) - vector_connector.delete_vector_name(vector_store_name) + field_vector_connector = VectorStoreConnector.from_default( + CFG.VECTOR_STORE_TYPE, + self.embeddings, + vector_store_config=field_vector_store_config, + ) + + table_vector_connector.delete_vector_name(table_vector_store_name) + field_vector_connector.delete_vector_name(field_vector_store_name) logger.info(f"delete db profile {dbname} success") @staticmethod