Skip to content

Commit

Permalink
fix: fix unit test error (#2085)
Browse files Browse the repository at this point in the history
Co-authored-by: aries_ckt <[email protected]>
Co-authored-by: Appointat <[email protected]>
  • Loading branch information
3 people authored Oct 22, 2024
1 parent 6d66678 commit d9e2042
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 113 deletions.
16 changes: 8 additions & 8 deletions dbgpt/datasource/conn_tugraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""TuGraph Connector."""

import json
from typing import Dict, Generator, List, Tuple, cast
from typing import Dict, Generator, Iterator, List, cast

from .base import BaseConnector

Expand All @@ -20,7 +20,7 @@ def __init__(self, driver, graph):
self._graph = graph
self._session = None

def create_graph(self, graph_name: str) -> None:
def create_graph(self, graph_name: str) -> bool:
"""Create a new graph in the database if it doesn't already exist."""
try:
with self._driver.session(database="default") as session:
Expand All @@ -33,6 +33,8 @@ def create_graph(self, graph_name: str) -> None:
except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e

return not exists

def delete_graph(self, graph_name: str) -> None:
"""Delete a graph in the database if it exists."""
with self._driver.session(database="default") as session:
Expand Down Expand Up @@ -60,20 +62,18 @@ def from_uri_db(
"`pip install neo4j`"
) from err

def get_table_names(self) -> Tuple[List[str], List[str]]:
def get_table_names(self) -> Iterator[str]:
"""Get all table names from the TuGraph by Neo4j driver."""
with self._driver.session(database=self._graph) as session:
# Run the query to get vertex labels
raw_vertex_labels: Dict[str, str] = session.run(
"CALL db.vertexLabels()"
).data()
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]

# Run the query to get edge labels
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
edge_labels = [table_name["label"] for table_name in raw_edge_labels]

return vertex_labels, edge_labels
return iter(vertex_labels + edge_labels)

def get_grants(self):
"""Get grants."""
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/rag/summary/gdbms_db_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def _parse_db_summary(
table_info_summaries = None
if isinstance(conn, TuGraphConnector):
table_names = conn.get_table_names()
v_tables = table_names.get("vertex_tables", [])
e_tables = table_names.get("edge_tables", [])
v_tables = table_names.get("vertex_tables", []) # type: ignore
e_tables = table_names.get("edge_tables", []) # type: ignore
table_info_summaries = [
_parse_table_summary(conn, summary_template, table_name, "vertex")
for table_name in v_tables
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,16 @@ def _upload_plugin(self):
if len(missing_plugins):
for name in missing_plugins:
try:
from dbgpt_tugraph_plugins import (
get_plugin_binary_path, # type:ignore[import-untyped]
from dbgpt_tugraph_plugins import ( # type: ignore
get_plugin_binary_path,
)
except ImportError:
logger.error(
"dbgpt-tugraph-plugins is not installed, "
"pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
"https://pypi.org/simple"
)
plugin_path = get_plugin_binary_path("leiden")
plugin_path = get_plugin_binary_path("leiden") # type: ignore
with open(plugin_path, "rb") as f:
content = f.read()
content = base64.b64encode(content).decode()
Expand Down
15 changes: 12 additions & 3 deletions dbgpt/storage/knowledge_graph/community/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Iterator, List, Optional, Union
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Union

from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import (
Expand Down Expand Up @@ -156,7 +156,11 @@ def create_graph(self, graph_name: str) -> None:
"""Create graph."""

@abstractmethod
def create_graph_label(self) -> None:
def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label.
The graph label is used to identify and distinguish different types of nodes
Expand All @@ -176,7 +180,12 @@ def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
from typing import AsyncGenerator, Dict, Iterator, List, Literal, Optional, Tuple, Union

from dbgpt.storage.graph_store.graph import (
Direction,
Expand Down Expand Up @@ -173,6 +173,8 @@ def create_graph(self, graph_name: str):

def create_graph_label(
self,
graph_elem_type: GraphElemType,
graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
"""Create a graph label.
Expand Down Expand Up @@ -201,9 +203,12 @@ def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: int | None = None,
fan: int | None = None,
limit: int | None = None,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
] = "knowledge_graph",
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""
return self._graph_store._graph.search(subs, direct, depth, fan, limit)
Expand Down
147 changes: 75 additions & 72 deletions dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def get_community(self, community_id: str) -> Community:
@property
def graph_store(self) -> TuGraphStore:
"""Get the graph store."""
return self._graph_store
return self._graph_store # type: ignore[return-value]

def get_graph_config(self):
"""Get the graph store config."""
Expand Down Expand Up @@ -176,29 +176,23 @@ def upsert_edge(
[{self._convert_dict_to_str(edge_list)}])"""
self.graph_store.conn.run(query=relation_query)

def upsert_chunks(
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None:
"""Upsert chunks."""
chunks_list = list(chunks)
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
for chunk in chunks_list
]
else:
chunk_list = [
{
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks_list
]
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
if isinstance(chunk, ParagraphChunk)
else {
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks
]

chunk_query = (
f"CALL db.upsertVertex("
f'"{GraphElemType.CHUNK.value}", '
Expand All @@ -207,28 +201,24 @@ def upsert_chunks(
self.graph_store.conn.run(query=chunk_query)

def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
) -> None:
"""Upsert documents."""
documents_list = list(documents)
if documents_list and isinstance(documents_list[0], ParagraphChunk):
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
for document in documents_list
]
else:
document_list = [
{
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": self._escape_quotes(document.get_prop("content")) or "",
}
for document in documents_list
]
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
if isinstance(document, ParagraphChunk)
else {
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": "",
}
for document in documents
]

document_query = (
"CALL db.upsertVertex("
f'"{GraphElemType.DOCUMENT.value}", '
Expand Down Expand Up @@ -258,7 +248,7 @@ def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
self.graph_store.conn.run(query=vertex_query)
self.graph_store.conn.run(query=edge_query)

def upsert_graph(self, graph: MemoryGraph) -> None:
def upsert_graph(self, graph: Graph) -> None:
"""Add graph to the graph store.
Args:
Expand Down Expand Up @@ -362,7 +352,8 @@ def drop(self):

def create_graph(self, graph_name: str):
"""Create a graph."""
self.graph_store.conn.create_graph(graph_name=graph_name)
if not self.graph_store.conn.create_graph(graph_name=graph_name):
return

# Create the graph schema
def _format_graph_propertity_schema(
Expand Down Expand Up @@ -474,12 +465,14 @@ def create_graph_label(
(vertices) and edges in the graph.
"""
if graph_elem_type.is_vertex(): # vertex
data = json.dumps({
"label": graph_elem_type.value,
"type": "VERTEX",
"primary": "id",
"properties": graph_properties,
})
data = json.dumps(
{
"label": graph_elem_type.value,
"type": "VERTEX",
"primary": "id",
"properties": graph_properties,
}
)
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
else: # edge

Expand All @@ -505,12 +498,14 @@ def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
else:
raise ValueError("Invalid graph element type.")

data = json.dumps({
"label": graph_elem_type.value,
"type": "EDGE",
"constraints": edge_direction(graph_elem_type),
"properties": graph_properties,
})
data = json.dumps(
{
"label": graph_elem_type.value,
"type": "EDGE",
"constraints": edge_direction(graph_elem_type),
"properties": graph_properties,
}
)
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""

self.graph_store.conn.run(gql)
Expand All @@ -530,18 +525,16 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool:
True if the label exists in the specified graph element type, otherwise
False.
"""
vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
tables = self.graph_store.conn.get_table_names()

if graph_elem_type.is_vertex():
return graph_elem_type in vertex_tables
else:
return graph_elem_type in edge_tables
return graph_elem_type.value in tables

def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: int = 3,
fan: Optional[int] = None,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
Expand Down Expand Up @@ -621,11 +614,17 @@ def query(self, query: str, **kwargs) -> MemoryGraph:
mg.append_edge(edge)
return mg

async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
# type: ignore[override]
# mypy: ignore-errors
async def stream_query( # type: ignore[override]
self,
query: str,
**kwargs,
) -> AsyncGenerator[Graph, None]:
"""Execute a stream query."""
from neo4j import graph

async for record in self.graph_store.conn.run_stream(query):
async for record in self.graph_store.conn.run_stream(query): # type: ignore
mg = MemoryGraph()
for key in record.keys():
value = record[key]
Expand All @@ -650,15 +649,19 @@ async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None
rels = list(record["p"].relationships)
formatted_path = []
for i in range(len(nodes)):
formatted_path.append({
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
})
formatted_path.append(
{
"id": nodes[i]._properties["id"],
"description": nodes[i]._properties["description"],
}
)
if i < len(rels):
formatted_path.append({
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
})
formatted_path.append(
{
"id": rels[i]._properties["id"],
"description": rels[i]._properties["description"],
}
)
for i in range(0, len(formatted_path), 2):
mg.upsert_vertex(
Vertex(
Expand Down
Loading

0 comments on commit d9e2042

Please sign in to comment.