Skip to content

Commit

Permalink
fix(rag): Fix db schema aretriever bug (#1755)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Jul 30, 2024
1 parent 55c8b39 commit 25d7d94
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 30 deletions.
20 changes: 0 additions & 20 deletions .github/workflows/sync-docs.yaml

This file was deleted.

6 changes: 4 additions & 2 deletions dbgpt/rag/retriever/db_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def _aretrieve(
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return result_candidates
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
else:
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
Expand All @@ -177,7 +177,9 @@ async def _aretrieve(
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
concurrency_limit=1,
)
return [Chunk(content=table_summary) for table_summary in table_summaries]
return [
Chunk(content=table_summary) for table_summary in table_summaries[0]
]

async def _aretrieve_with_score(
self,
Expand Down
22 changes: 14 additions & 8 deletions dbgpt/rag/retriever/tests/test_db_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,35 @@ def mock_vector_store_connector():


@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
return DBSchemaRetriever(
connector=mock_db_connection,
index_store=mock_vector_store_connector,
)


def mock_parse_db_summary() -> str:
def mock_parse_db_summary(conn) -> List[str]:
"""Patch _parse_db_summary method."""
return "Table summary"
return ["Table summary"]


# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
def test_retrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"


async def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return "Table summary"
@pytest.mark.asyncio
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"

0 comments on commit 25d7d94

Please sign in to comment.