fix: Refactor graph store configuration and enable/disable graph search
Appointat committed Oct 30, 2024
1 parent e6f6d33 commit 1ff3184
Showing 4 changed files with 101 additions and 66 deletions.
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/base.py
@@ -27,14 +27,6 @@ class GraphStoreConfig(BaseModel):
         default=False,
         description="Enable graph community summary or not.",
     )
-    document_graph_enabled: bool = Field(
-        default=True,
-        description="Enable document graph search or not.",
-    )
-    triplet_graph_enabled: bool = Field(
-        default=True,
-        description="Enable knowledge graph search or not.",
-    )


 class GraphStoreBase(ABC):
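The two flags removed here are not dropped from the project: they reappear on CommunitySummaryKnowledgeGraphConfig in dbgpt/storage/knowledge_graph/community_summary.py (the last file in this commit), since they govern retrieval behavior rather than storage. A minimal sketch of setting them after this change (field names are taken from the diff below; any other constructor arguments are assumed to keep their defaults):

    from dbgpt.storage.knowledge_graph.community_summary import (
        CommunitySummaryKnowledgeGraphConfig,
    )

    # The search toggles now live on the knowledge-graph config,
    # not on GraphStoreConfig.
    config = CommunitySummaryKnowledgeGraphConfig(
        triplet_graph_enabled=True,    # keyword -> entity subgraph search
        document_graph_enabled=False,  # skip the document/chunk graph search
    )
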
8 changes: 0 additions & 8 deletions dbgpt/storage/graph_store/tugraph_store.py
@@ -83,14 +83,6 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
             os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
             or config.enable_summary
         )
-        self._enable_document_graph = (
-            os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
-            or config.document_graph_enabled
-        )
-        self._enable_triplet_graph = (
-            os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
-            or config.triplet_graph_enabled
-        )
         self._plugin_names = (
             os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
             or config.plugin_names
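Removing this block also fixes a latent bug in the `or` pattern: because `os.getenv(X, "").lower() == "true" or config.flag` falls back to the config value whenever the environment string is not exactly "true", setting DOCUMENT_GRAPH_ENABLED=false could never disable the feature while the config default was True. The replacement logic in community_summary.py below gives a set environment variable precedence. A small runnable sketch of the difference, with True standing in for the config default:

    import os

    os.environ["TRIPLET_GRAPH_ENABLED"] = "false"

    # Old pattern: "false" != "true", so the check falls through to the
    # config default and the explicit env setting is silently ignored.
    old = os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true" or True
    print(old)  # True: the feature cannot be switched off via the environment

    # New pattern (see community_summary.py below): a set variable wins.
    new = (
        os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
        if "TRIPLET_GRAPH_ENABLED" in os.environ
        else True
    )
    print(new)  # False
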
116 changes: 71 additions & 45 deletions dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
@@ -465,14 +465,12 @@ def create_graph_label(
         (vertices) and edges in the graph.
         """
         if graph_elem_type.is_vertex():  # vertex
-            data = json.dumps(
-                {
-                    "label": graph_elem_type.value,
-                    "type": "VERTEX",
-                    "primary": "id",
-                    "properties": graph_properties,
-                }
-            )
+            data = json.dumps({
+                "label": graph_elem_type.value,
+                "type": "VERTEX",
+                "primary": "id",
+                "properties": graph_properties,
+            })
             gql = f"""CALL db.createVertexLabelByJson('{data}')"""
         else:  # edge

@@ -498,14 +496,12 @@ def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
             else:
                 raise ValueError("Invalid graph element type.")

-            data = json.dumps(
-                {
-                    "label": graph_elem_type.value,
-                    "type": "EDGE",
-                    "constraints": edge_direction(graph_elem_type),
-                    "properties": graph_properties,
-                }
-            )
+            data = json.dumps({
+                "label": graph_elem_type.value,
+                "type": "EDGE",
+                "constraints": edge_direction(graph_elem_type),
+                "properties": graph_properties,
+            })
             gql = f"""CALL db.createEdgeLabelByJson('{data}')"""

         self.graph_store.conn.run(gql)
@@ -544,7 +540,7 @@ def explore(
         if not subs:
             return MemoryGraph()

-        if depth < 0:
+        if depth <= 0:
             depth = 3
         depth_string = f"1..{depth}"

@@ -560,42 +556,76 @@
                 rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
             else:
                 rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
-            path_query = (
+            query = (
                 f"MATCH p=(n:{GraphElemType.ENTITY.value})"
                 f"{rel}(m:{GraphElemType.ENTITY.value}) "
                 f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
-                f"RETURN n {limit_string}"
+                f"RETURN p {limit_string}"
             )
-            return self.query(path_query, white_list=["description"])
+            return self.query(query=query, white_list=["description"])
         else:
             graph = MemoryGraph()
+            check_entity_query = (
+                f"MATCH (n:{GraphElemType.ENTITY.value}) "
+                f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
+                "RETURN n"
+            )

-            for sub in subs:
+            if self.query(check_entity_query):
+                # Query the chain from documents to chunks,
+                # document -> chunk -> chunk -> chunk -> ...
+                # document -> chunk -> ... -> chunk (-> entity, do not reach entity)
+                chain_query = (
+                    f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
+                    f"[:{GraphElemType.INCLUDE.value}*1..{depth + 1}]->"
+                    f"(leaf_chunk:{GraphElemType.CHUNK.value})-[:{GraphElemType.INCLUDE.value}]->"
+                    f"(m:{GraphElemType.ENTITY.value}) "
+                    f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
+                    # "WITH n, leaf_chunk "
+                    # f"MATCH p = (n)-[:{GraphElemType.INCLUDE.value}*1..{depth}]->(leaf_chunk:{GraphElemType.CHUNK.value}) "
+                    "RETURN p"
+                )
+                # Filter all the properties by with_list
+                graph.upsert_graph(self.query(query=chain_query, white_list=[""]))
+
+                # Query the leaf chunks in the chain from documents to chunks
+                leaf_chunk_query = (
+                    f"MATCH p=(n:{GraphElemType.CHUNK.value})-"
+                    f"[r:{GraphElemType.INCLUDE.value}]->"
+                    f"(m:{GraphElemType.ENTITY.value})"
+                    f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
+                    f"RETURN n {limit_string}"
+                )
+                graph.upsert_graph(
+                    self.query(query=leaf_chunk_query, white_list=["content"])
+                )
+            else:
+                _subs_condition = " OR ".join([
+                    f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs
+                ])

                 # Query the chain from documents to chunks,
                 # document -> chunk -> chunk -> chunk -> ... -> chunk
                 chain_query = (
                     f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                     f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]->"
-                    f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS "
-                    f"'{self._escape_quotes(sub)}' "
-                    f"RETURN p"
+                    f"(m:{GraphElemType.CHUNK.value})"
+                    f"WHERE {_subs_condition}"
+                    "RETURN p"
                 )
-                # Query and filter all the properties
-                graph_of_path = self.query(query=chain_query, white_list=[""])
-                graph.upsert_graph(graph_of_path)
+                # Filter all the properties by with_list
+                graph.upsert_graph(self.query(query=chain_query, white_list=[""]))

                 # Query the leaf chunks in the chain from documents to chunks
                 leaf_chunk_query = (
                     f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                     f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]->"
-                    f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS "
-                    f"'{self._escape_quotes(sub)}' "
+                    f"(m:{GraphElemType.CHUNK.value})"
+                    f"WHERE {_subs_condition}"
                     f"RETURN m {limit_string}"
                 )
-                graph_of_leaf_chunk = self.query(
-                    query=leaf_chunk_query, white_list=["content"]
+                graph.upsert_graph(
+                    self.query(query=leaf_chunk_query, white_list=["content"])
                 )
-                graph.upsert_graph(graph_of_leaf_chunk)

             return graph

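For orientation, the entity-anchored branch above now issues one Cypher statement per retrieval rather than one per keyword. A rough, self-contained rendering of the chain query for subs = ["Alice"] and depth = 3; the literal label strings (document, chunk, entity, include) are assumptions for illustration, since the GraphElemType values are not shown in this diff:

    # Hypothetical illustration only; label names are assumed, not taken
    # from this diff.
    subs = ["Alice"]
    depth = 3

    chain_query = (
        "MATCH p=(n:document)-"
        f"[:include*1..{depth + 1}]->"
        "(leaf_chunk:chunk)-[:include]->"
        "(m:entity) "
        f"WHERE m.name IN {subs} "
        "RETURN p"
    )
    print(chain_query)
    # MATCH p=(n:document)-[:include*1..4]->(leaf_chunk:chunk)-[:include]->(m:entity) WHERE m.name IN ['Alice'] RETURN p
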
@@ -663,19 +693,15 @@ async def stream_query(  # type: ignore[override]
             rels = list(record["p"].relationships)
             formatted_path = []
             for i in range(len(nodes)):
-                formatted_path.append(
-                    {
-                        "id": nodes[i]._properties["id"],
-                        "description": nodes[i]._properties["description"],
-                    }
-                )
+                formatted_path.append({
+                    "id": nodes[i]._properties["id"],
+                    "description": nodes[i]._properties["description"],
+                })
                 if i < len(rels):
-                    formatted_path.append(
-                        {
-                            "id": rels[i]._properties["id"],
-                            "description": rels[i]._properties["description"],
-                        }
-                    )
+                    formatted_path.append({
+                        "id": rels[i]._properties["id"],
+                        "description": rels[i]._properties["description"],
+                    })
             for i in range(0, len(formatted_path), 2):
                 mg.upsert_vertex(
                     Vertex(
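For context on the loop being reformatted above: formatted_path interleaves vertex and relationship entries (node, edge, node, ...), which is why the follow-up loop walks even indices to upsert vertices. A toy version over plain dicts (the real code reads _properties from TuGraph records):

    nodes = [{"id": "A", "description": "entity A"},
             {"id": "B", "description": "entity B"}]
    rels = [{"id": "A->B", "description": "A relates to B"}]

    formatted_path = []
    for i in range(len(nodes)):
        formatted_path.append(nodes[i])
        if i < len(rels):
            formatted_path.append(rels[i])

    # Even indices hold vertices, odd indices hold relationships.
    print([e["id"] for e in formatted_path[0::2]])  # ['A', 'B']
    print([e["id"] for e in formatted_path[1::2]])  # ['A->B']
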
35 changes: 30 additions & 5 deletions dbgpt/storage/knowledge_graph/community_summary.py
@@ -59,6 +59,15 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
         default=0.0,
         description="Recall score of community search in knowledge graph",
     )
+    triplet_graph_enabled: bool = Field(
+        default=True,
+        description="Enable the graph search for triplets",
+    )
+    document_graph_enabled: bool = Field(
+        default=True,
+        description="Enable the graph search for documents and chunks",
+    )
+
     knowledge_graph_chunk_search_top_size: int = Field(
         default=5,
         description="Top size of knowledge graph chunk search",
@@ -100,6 +109,20 @@ def __init__(self, config: CommunitySummaryKnowledgeGraphConfig):
                 config.community_score_threshold,
             )
         )
+        self._document_graph_enabled = bool(
+            (
+                os.environ["DOCUMENT_GRAPH_ENABLED"].lower() == "true"
+                if "DOCUMENT_GRAPH_ENABLED" in os.environ
+                else config.document_graph_enabled
+            )
+        )
+        self._triplet_graph_enabled = bool(
+            (
+                os.environ["TRIPLET_GRAPH_ENABLED"].lower() == "true"
+                if "TRIPLET_GRAPH_ENABLED" in os.environ
+                else config.triplet_graph_enabled
+            )
+        )
         self._knowledge_graph_chunk_search_top_size = int(
             os.getenv(
                 "KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE",
@@ -170,7 +193,7 @@ async def _aload_document_graph(self, chunks: List[Chunk]) -> None:
         The chunks include the doc structure.
         """
-        if not self._graph_store.get_config().document_graph_enabled:
+        if not self._document_graph_enabled:
             return

         _chunks: List[ParagraphChunk] = [
@@ -201,10 +224,10 @@ async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
         The chunks include the doc structure.
         """
-        if not self._graph_store.get_config().triplet_graph_enabled:
+        if not self._triplet_graph_enabled:
             return

-        document_graph_enabled = self._graph_store.get_config().document_graph_enabled
+        document_graph_enabled = self._document_graph_enabled

         # Extract the triplets from the chunks, and return the list of graphs
         # in the same order as the input texts
@@ -303,10 +326,12 @@ async def asimilar_search_with_scores(
         context = "\n".join(summaries) if summaries else ""

         keywords: List[str] = await self._keyword_extractor.extract(text)
+        subgraph = None
+        subgraph_for_doc = None
+
         # Local search: extract keywords and explore subgraph
-        triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
-        document_graph_enabled = self._graph_store.get_config().document_graph_enabled
+        triplet_graph_enabled = self._triplet_graph_enabled
+        document_graph_enabled = self._document_graph_enabled

         if triplet_graph_enabled:
             subgraph: MemoryGraph = self._graph_store_apdater.explore(
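Net effect of the two flags on this method: with triplet_graph_enabled off, subgraph stays None; with document_graph_enabled off, subgraph_for_doc stays None; with both off, retrieval falls back to the community summaries alone. A small sketch of that matrix (a simplification of the control flow above, not the real method body):

    def local_search_plan(triplet_on: bool, document_on: bool) -> list:
        """Sketch of which local searches run, per the flags above."""
        searches = []
        if triplet_on:
            searches.append("entity subgraph via explore()")
        if document_on:
            searches.append("document/chunk graph")
        return searches or ["community summaries only"]

    print(local_search_plan(True, False))   # ['entity subgraph via explore()']
    print(local_search_plan(False, False))  # ['community summaries only']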
