diff --git a/dbgpt/datasource/conn_tugraph.py b/dbgpt/datasource/conn_tugraph.py index 191bfea87..c917365ed 100644 --- a/dbgpt/datasource/conn_tugraph.py +++ b/dbgpt/datasource/conn_tugraph.py @@ -1,7 +1,7 @@ """TuGraph Connector.""" import json -from typing import Dict, Generator, List, Tuple, cast +from typing import Dict, Generator, Iterator, List, cast from .base import BaseConnector @@ -20,7 +20,7 @@ def __init__(self, driver, graph): self._graph = graph self._session = None - def create_graph(self, graph_name: str) -> None: + def create_graph(self, graph_name: str) -> bool: """Create a new graph in the database if it doesn't already exist.""" try: with self._driver.session(database="default") as session: @@ -33,6 +33,8 @@ def create_graph(self, graph_name: str) -> None: except Exception as e: raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e + return not exists + def delete_graph(self, graph_name: str) -> None: """Delete a graph in the database if it exists.""" with self._driver.session(database="default") as session: @@ -60,20 +62,18 @@ def from_uri_db( "`pip install neo4j`" ) from err - def get_table_names(self) -> Tuple[List[str], List[str]]: + def get_table_names(self) -> Iterator[str]: """Get all table names from the TuGraph by Neo4j driver.""" with self._driver.session(database=self._graph) as session: # Run the query to get vertex labels - raw_vertex_labels: Dict[str, str] = session.run( - "CALL db.vertexLabels()" - ).data() + raw_vertex_labels = session.run("CALL db.vertexLabels()").data() vertex_labels = [table_name["label"] for table_name in raw_vertex_labels] # Run the query to get edge labels - raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data() + raw_edge_labels = session.run("CALL db.edgeLabels()").data() edge_labels = [table_name["label"] for table_name in raw_edge_labels] - return vertex_labels, edge_labels + return iter(vertex_labels + edge_labels) def get_grants(self): """Get grants.""" diff --git a/dbgpt/rag/summary/gdbms_db_summary.py b/dbgpt/rag/summary/gdbms_db_summary.py index 7a1f1ec41..2590caa2d 100644 --- a/dbgpt/rag/summary/gdbms_db_summary.py +++ b/dbgpt/rag/summary/gdbms_db_summary.py @@ -76,8 +76,8 @@ def _parse_db_summary( table_info_summaries = None if isinstance(conn, TuGraphConnector): table_names = conn.get_table_names() - v_tables = table_names.get("vertex_tables", []) - e_tables = table_names.get("edge_tables", []) + v_tables = table_names.get("vertex_tables", []) # type: ignore + e_tables = table_names.get("edge_tables", []) # type: ignore table_info_summaries = [ _parse_table_summary(conn, summary_template, table_name, "vertex") for table_name in v_tables diff --git a/dbgpt/storage/graph_store/tugraph_store.py b/dbgpt/storage/graph_store/tugraph_store.py index 45526e3fb..c20965947 100644 --- a/dbgpt/storage/graph_store/tugraph_store.py +++ b/dbgpt/storage/graph_store/tugraph_store.py @@ -141,8 +141,8 @@ def _upload_plugin(self): if len(missing_plugins): for name in missing_plugins: try: - from dbgpt_tugraph_plugins import ( - get_plugin_binary_path, # type:ignore[import-untyped] + from dbgpt_tugraph_plugins import ( # type: ignore + get_plugin_binary_path, ) except ImportError: logger.error( @@ -150,7 +150,7 @@ def _upload_plugin(self): "pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i " "https://pypi.org/simple" ) - plugin_path = get_plugin_binary_path("leiden") + plugin_path = get_plugin_binary_path("leiden") # type: ignore with open(plugin_path, "rb") as f: content = f.read() content = base64.b64encode(content).decode() diff --git a/dbgpt/storage/knowledge_graph/community/base.py b/dbgpt/storage/knowledge_graph/community/base.py index 6e4609d8b..556587c91 100644 --- a/dbgpt/storage/knowledge_graph/community/base.py +++ b/dbgpt/storage/knowledge_graph/community/base.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import AsyncGenerator, Iterator, List, Optional, Union +from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union from dbgpt.storage.graph_store.base import GraphStoreBase from dbgpt.storage.graph_store.graph import ( @@ -156,7 +156,11 @@ def create_graph(self, graph_name: str) -> None: """Create graph.""" @abstractmethod - def create_graph_label(self) -> None: + def create_graph_label( + self, + graph_elem_type: GraphElemType, + graph_properties: List[Dict[str, Union[str, bool]]], + ) -> None: """Create a graph label. The graph label is used to identify and distinguish different types of nodes @@ -176,7 +180,12 @@ def explore( self, subs: List[str], direct: Direction = Direction.BOTH, - depth: Optional[int] = None, + depth: int = 3, + fan: Optional[int] = None, + limit: Optional[int] = None, + search_scope: Optional[ + Literal["knowledge_graph", "document_graph"] + ] = "knowledge_graph", ) -> MemoryGraph: """Explore the graph from given subjects up to a depth.""" diff --git a/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py b/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py index 2af119958..6fd3b006d 100644 --- a/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py +++ b/dbgpt/storage/knowledge_graph/community/memgraph_store_adapter.py @@ -2,7 +2,7 @@ import json import logging -from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union +from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union from dbgpt.storage.graph_store.graph import ( Direction, @@ -173,6 +173,8 @@ def create_graph(self, graph_name: str): def create_graph_label( self, + graph_elem_type: GraphElemType, + graph_properties: List[Dict[str, Union[str, bool]]], ) -> None: """Create a graph label. @@ -201,9 +203,12 @@ def explore( self, subs: List[str], direct: Direction = Direction.BOTH, - depth: int | None = None, - fan: int | None = None, - limit: int | None = None, + depth: int = 3, + fan: Optional[int] = None, + limit: Optional[int] = None, + search_scope: Optional[ + Literal["knowledge_graph", "document_graph"] + ] = "knowledge_graph", ) -> MemoryGraph: """Explore the graph from given subjects up to a depth.""" return self._graph_store._graph.search(subs, direct, depth, fan, limit) diff --git a/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py b/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py index f58a256fb..d65969d76 100644 --- a/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py +++ b/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py @@ -79,7 +79,7 @@ async def get_community(self, community_id: str) -> Community: @property def graph_store(self) -> TuGraphStore: """Get the graph store.""" - return self._graph_store + return self._graph_store # type: ignore[return-value] def get_graph_config(self): """Get the graph store config.""" @@ -176,29 +176,23 @@ def upsert_edge( [{self._convert_dict_to_str(edge_list)}])""" self.graph_store.conn.run(query=relation_query) - def upsert_chunks( - self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]] - ) -> None: + def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None: """Upsert chunks.""" - chunks_list = list(chunks) - if chunks_list and isinstance(chunks_list[0], ParagraphChunk): - chunk_list = [ - { - "id": self._escape_quotes(chunk.chunk_id), - "name": self._escape_quotes(chunk.chunk_name), - "content": self._escape_quotes(chunk.content), - } - for chunk in chunks_list - ] - else: - chunk_list = [ - { - "id": self._escape_quotes(chunk.vid), - "name": self._escape_quotes(chunk.name), - "content": self._escape_quotes(chunk.get_prop("content")), - } - for chunk in chunks_list - ] + chunk_list = [ + { + "id": self._escape_quotes(chunk.chunk_id), + "name": self._escape_quotes(chunk.chunk_name), + "content": self._escape_quotes(chunk.content), + } + if isinstance(chunk, ParagraphChunk) + else { + "id": self._escape_quotes(chunk.vid), + "name": self._escape_quotes(chunk.name), + "content": self._escape_quotes(chunk.get_prop("content")), + } + for chunk in chunks + ] + chunk_query = ( f"CALL db.upsertVertex(" f'"{GraphElemType.CHUNK.value}", ' @@ -207,28 +201,24 @@ def upsert_chunks( self.graph_store.conn.run(query=chunk_query) def upsert_documents( - self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]] + self, documents: Iterator[Union[Vertex, ParagraphChunk]] ) -> None: """Upsert documents.""" - documents_list = list(documents) - if documents_list and isinstance(documents_list[0], ParagraphChunk): - document_list = [ - { - "id": self._escape_quotes(document.chunk_id), - "name": self._escape_quotes(document.chunk_name), - "content": "", - } - for document in documents_list - ] - else: - document_list = [ - { - "id": self._escape_quotes(document.vid), - "name": self._escape_quotes(document.name), - "content": self._escape_quotes(document.get_prop("content")) or "", - } - for document in documents_list - ] + document_list = [ + { + "id": self._escape_quotes(document.chunk_id), + "name": self._escape_quotes(document.chunk_name), + "content": "", + } + if isinstance(document, ParagraphChunk) + else { + "id": self._escape_quotes(document.vid), + "name": self._escape_quotes(document.name), + "content": "", + } + for document in documents + ] + document_query = ( "CALL db.upsertVertex(" f'"{GraphElemType.DOCUMENT.value}", ' @@ -258,7 +248,7 @@ def insert_triplet(self, subj: str, rel: str, obj: str) -> None: self.graph_store.conn.run(query=vertex_query) self.graph_store.conn.run(query=edge_query) - def upsert_graph(self, graph: MemoryGraph) -> None: + def upsert_graph(self, graph: Graph) -> None: """Add graph to the graph store. Args: @@ -362,7 +352,8 @@ def drop(self): def create_graph(self, graph_name: str): """Create a graph.""" - self.graph_store.conn.create_graph(graph_name=graph_name) + if not self.graph_store.conn.create_graph(graph_name=graph_name): + return # Create the graph schema def _format_graph_propertity_schema( @@ -474,12 +465,14 @@ def create_graph_label( (vertices) and edges in the graph. """ if graph_elem_type.is_vertex(): # vertex - data = json.dumps({ - "label": graph_elem_type.value, - "type": "VERTEX", - "primary": "id", - "properties": graph_properties, - }) + data = json.dumps( + { + "label": graph_elem_type.value, + "type": "VERTEX", + "primary": "id", + "properties": graph_properties, + } + ) gql = f"""CALL db.createVertexLabelByJson('{data}')""" else: # edge @@ -505,12 +498,14 @@ def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]: else: raise ValueError("Invalid graph element type.") - data = json.dumps({ - "label": graph_elem_type.value, - "type": "EDGE", - "constraints": edge_direction(graph_elem_type), - "properties": graph_properties, - }) + data = json.dumps( + { + "label": graph_elem_type.value, + "type": "EDGE", + "constraints": edge_direction(graph_elem_type), + "properties": graph_properties, + } + ) gql = f"""CALL db.createEdgeLabelByJson('{data}')""" self.graph_store.conn.run(gql) @@ -530,18 +525,16 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool: True if the label exists in the specified graph element type, otherwise False. """ - vertex_tables, edge_tables = self.graph_store.conn.get_table_names() + tables = self.graph_store.conn.get_table_names() - if graph_elem_type.is_vertex(): - return graph_elem_type in vertex_tables - else: - return graph_elem_type in edge_tables + return graph_elem_type.value in tables def explore( self, subs: List[str], direct: Direction = Direction.BOTH, depth: int = 3, + fan: Optional[int] = None, limit: Optional[int] = None, search_scope: Optional[ Literal["knowledge_graph", "document_graph"] @@ -621,11 +614,17 @@ def query(self, query: str, **kwargs) -> MemoryGraph: mg.append_edge(edge) return mg - async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]: + # type: ignore[override] + # mypy: ignore-errors + async def stream_query( # type: ignore[override] + self, + query: str, + **kwargs, + ) -> AsyncGenerator[Graph, None]: """Execute a stream query.""" from neo4j import graph - async for record in self.graph_store.conn.run_stream(query): + async for record in self.graph_store.conn.run_stream(query): # type: ignore mg = MemoryGraph() for key in record.keys(): value = record[key] @@ -650,15 +649,19 @@ async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None rels = list(record["p"].relationships) formatted_path = [] for i in range(len(nodes)): - formatted_path.append({ - "id": nodes[i]._properties["id"], - "description": nodes[i]._properties["description"], - }) + formatted_path.append( + { + "id": nodes[i]._properties["id"], + "description": nodes[i]._properties["description"], + } + ) if i < len(rels): - formatted_path.append({ - "id": rels[i]._properties["id"], - "description": rels[i]._properties["description"], - }) + formatted_path.append( + { + "id": rels[i]._properties["id"], + "description": rels[i]._properties["description"], + } + ) for i in range(0, len(formatted_path), 2): mg.upsert_vertex( Vertex( diff --git a/dbgpt/storage/knowledge_graph/community_summary.py b/dbgpt/storage/knowledge_graph/community_summary.py index c37c0d0d8..cbcacf023 100644 --- a/dbgpt/storage/knowledge_graph/community_summary.py +++ b/dbgpt/storage/knowledge_graph/community_summary.py @@ -149,25 +149,25 @@ async def aload_document(self, chunks: List[Chunk]) -> List[str]: return [chunk.chunk_id for chunk in chunks] - async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]: + async def _aload_document_graph(self, chunks: List[Chunk]) -> None: """Load the knowledge graph from the chunks. The chunks include the doc structure. """ if not self._graph_store.get_config().document_graph_enabled: - return [] + return - chunks: List[ParagraphChunk] = [ + _chunks: List[ParagraphChunk] = [ ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks ] - documment_chunk, chunks = self._load_chunks(chunks) + documment_chunk, paragraph_chunks = self._load_chunks(_chunks) # upsert the document and chunks vertices self._graph_store_apdater.upsert_documents(iter([documment_chunk])) - self._graph_store_apdater.upsert_chunks(iter(chunks)) + self._graph_store_apdater.upsert_chunks(iter(paragraph_chunks)) # upsert the document structure - for chunk_index, chunk in enumerate(chunks): + for chunk_index, chunk in enumerate(paragraph_chunks): # document -> include -> chunk if chunk.parent_is_document: self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk) @@ -177,7 +177,7 @@ async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]: # chunk -> next -> chunk if chunk_index >= 1: self._graph_store_apdater.upsert_chunk_next_chunk( - chunk=chunks[chunk_index - 1], next_chunk=chunk + chunk=paragraph_chunks[chunk_index - 1], next_chunk=chunk ) async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None: @@ -280,7 +280,7 @@ def similar_search( Return: List[Chunk]: The similar documents. """ - pass + return [] async def asimilar_search_with_scores( self, @@ -301,9 +301,6 @@ async def asimilar_search_with_scores( keywords: List[str] = await self._keyword_extractor.extract(text) # Local search: extract keywords and explore subgraph - subgraph = MemoryGraph() - subgraph_for_doc = MemoryGraph() - triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled document_graph_enabled = self._graph_store.get_config().document_graph_enabled @@ -329,9 +326,10 @@ async def asimilar_search_with_scores( limit=self._config.knowledge_graph_chunk_search_top_size, search_scope="document_graph", ) - - knowledge_graph_str = subgraph.format() - knowledge_graph_for_doc_str = subgraph_for_doc.format() + knowledge_graph_str = subgraph.format() if subgraph else "" + knowledge_graph_for_doc_str = ( + subgraph_for_doc.format() if subgraph_for_doc else "" + ) logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}") diff --git a/dbgpt/storage/knowledge_graph/knowledge_graph.py b/dbgpt/storage/knowledge_graph/knowledge_graph.py index 10d9134aa..ef2d15039 100644 --- a/dbgpt/storage/knowledge_graph/knowledge_graph.py +++ b/dbgpt/storage/knowledge_graph/knowledge_graph.py @@ -183,5 +183,5 @@ def delete_vector_name(self, index_name: str): def delete_by_ids(self, ids: str) -> List[str]: """Delete by ids.""" - self._graph_store_apdater.delete_document(chunk_ids=ids) + self._graph_store_apdater.delete_document(chunk_id=ids) return [] diff --git a/examples/rag/graph_rag_example.py b/examples/rag/graph_rag_example.py index 40d5b01e3..825a2f20e 100644 --- a/examples/rag/graph_rag_example.py +++ b/examples/rag/graph_rag_example.py @@ -88,7 +88,8 @@ def __create_community_kg_connector(): async def ask_chunk(chunk: Chunk, question) -> str: rag_template = ( - "Based on the following [Context] {context}, " "answer [Question] {question}." + "Based on the following [Context] {context}, " + "answer [Question] {question}." ) template = HumanPromptTemplate.from_template(rag_template) messages = template.format_messages(context=chunk.content, question=question) diff --git a/tests/intetration_tests/graph_store/test_tugraph_store.py b/tests/intetration_tests/graph_store/test_tugraph_store.py index d02a2ca90..a1f26b78f 100644 --- a/tests/intetration_tests/graph_store/test_tugraph_store.py +++ b/tests/intetration_tests/graph_store/test_tugraph_store.py @@ -46,9 +46,9 @@ def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter): assert len(triplets) == 1 -def test_query(store: TuGraphStore): +def test_query(tugraph_store_adapter: TuGraphStoreAdapter): query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3" - result = store.query(query) + result = tugraph_store_adapter.query(query) v_c = result.vertex_count e_c = result.edge_count assert v_c == 3 and e_c == 3 diff --git a/tests/unit_tests/graph/test_graph.py b/tests/unit_tests/graph/test_graph.py index b8900dc1b..200de6784 100644 --- a/tests/unit_tests/graph/test_graph.py +++ b/tests/unit_tests/graph/test_graph.py @@ -25,12 +25,12 @@ def g(): (lambda g: g.del_vertices("G", "G"), 6, 9), (lambda g: g.del_vertices("C"), 6, 7), (lambda g: g.del_vertices("A", "G"), 5, 6), - (lambda g: g.del_edges("A", "A"), 7, 7), - (lambda g: g.del_edges("A", "B"), 7, 8), + (lambda g: g.del_edges("A", "A", None), 7, 7), + (lambda g: g.del_edges("A", "B", None), 7, 8), (lambda g: g.del_edges("A", "A", "0"), 7, 8), (lambda g: g.del_edges("E", "F", "8"), 7, 8), (lambda g: g.del_edges("E", "F", "9"), 7, 9), - (lambda g: g.del_edges("E", "F", val=1), 7, 9), + (lambda g: g.del_edges("E", "F", None, val=1), 7, 9), (lambda g: g.del_edges("E", "F", "8", val=1), 7, 9), (lambda g: g.del_edges("E", "F", "9", val=1), 7, 9), (lambda g: g.del_neighbor_edges("A", Direction.IN), 7, 7),