Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat rdb summary wide table #2035

Merged
merged 15 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ QUANTIZE_8bit=True
#** EMBEDDING SETTINGS **#
#*******************************************************************#
EMBEDDING_MODEL=text2vec
EMBEDDING_MODEL_MAX_SEQ_LEN=512
#EMBEDDING_MODEL=m3e-large
#EMBEDDING_MODEL=bge-large-en
#EMBEDDING_MODEL=bge-large-zh
Expand Down
3 changes: 3 additions & 0 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ def __init__(self) -> None:

# EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512)
)
# Rerank model configuration
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
Expand Down
3 changes: 0 additions & 3 deletions dbgpt/app/scene/chat_db/professional_qa/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ async def generate_input_values(self) -> Dict:
if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
# table_infos = client.get_db_summary(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
Expand Down
92 changes: 81 additions & 11 deletions dbgpt/rag/assembler/db_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""DBSchemaAssembler."""
import os
from typing import Any, List, Optional

from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector

from ...serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever

Expand Down Expand Up @@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
def __init__(
self,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.

Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: VectorStoreConnector to load
and retrieve table info.
field_vector_store_connector: VectorStoreConnector to load
and retrieve field info.
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
knowledge = DatasourceKnowledge(connector)
self._connector = connector
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector

field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)

self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)

if (
embeddings
and self._table_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._table_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
if (
embeddings
and self._field_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._field_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
Expand All @@ -62,23 +106,36 @@ def __init__(
def load_from_connection(
cls,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.

Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: used to load table chunks.
field_vector_store_connector: used to load field chunks
if field in table is too much.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
max_seq_length: Embedding model max sequence length
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
index_store=index_store,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
max_seq_length=max_seq_length,
)

def get_chunks(self) -> List[Chunk]:
Expand All @@ -91,7 +148,19 @@ def persist(self, **kwargs: Any) -> List[str]:
Returns:
List[str]: List of chunk ids.
"""
return self._index_store.load_document(self._chunks)
table_chunks, field_chunks = [], []
for chunk in self._chunks:
metadata = chunk.metadata
if metadata.get("separated"):
if metadata.get("part") == "table":
table_chunks.append(chunk)
else:
field_chunks.append(chunk)
else:
table_chunks.append(chunk)

self._field_vector_store_connector.load_document(field_chunks)
return self._table_vector_store_connector.load_document(table_chunks)

def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
Expand All @@ -110,5 +179,6 @@ def as_retriever(self, top_k: int = 4, **kwargs) -> DBSchemaRetriever:
top_k=top_k,
connector=self._connector,
is_embeddings=True,
index_store=self._index_store,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)
151 changes: 96 additions & 55 deletions dbgpt/rag/assembler/tests/test_db_struct_assembler.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,117 @@
from unittest.mock import MagicMock
from typing import List
from unittest.mock import MagicMock, patch

import pytest

from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
import dbgpt
from dbgpt.core import Chunk
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary


@pytest.fixture
def mock_db_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
)
return connect
return MagicMock()


@pytest.fixture
def mock_chunk_parameters():
return MagicMock(spec=ChunkParameters)
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.vector_store_config.name = "table_name"
chunk = Chunk(
content="table_name: user\ncomment: user about dbgpt",
metadata={
"field_num": 6,
"part": "table",
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
return mock_connector


@pytest.fixture
def mock_embedding_factory():
return MagicMock(spec=EmbeddingFactory)
def mock_field_vector_store_connector():
mock_connector = MagicMock()
chunk1 = Chunk(
content="name,age",
metadata={
"field_num": 6,
"part": "field",
"part_index": 0,
"separated": 1,
"table_name": "user",
},
)
chunk2 = Chunk(
content="address,gender",
metadata={
"field_num": 6,
"part": "field",
"part_index": 1,
"separated": 1,
"table_name": "user",
},
)
chunk3 = Chunk(
content="mail,phone",
metadata={
"field_num": 6,
"part": "field",
"part_index": 2,
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(
return_value=[chunk1, chunk2, chunk3]
)
return mock_connector


@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
def dbstruct_retriever(
mock_db_connection,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
return DBSchemaRetriever(
connector=mock_db_connection,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
separator="--table-field-separator--",
)


@pytest.fixture
def mock_knowledge():
return MagicMock(spec=Knowledge)
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)


def test_load_knowledge(
mock_db_connection,
mock_knowledge,
mock_chunk_parameters,
mock_embedding_factory,
mock_vector_store_connector,
):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = EmbeddingAssembler(
knowledge=mock_knowledge,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)


def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)
assembler.load_knowledge(knowledge=mock_knowledge)
assert len(assembler._chunks) == 0
Loading
Loading