Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support embedding similarity search for GraphRAG #2200

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ def drop(self):
class EmbedderBase(TransformerBase, ABC):
"""Embedder base class."""

@abstractmethod
async def embed(self, text: str) -> List[float]:
    """Embed a single text into a vector.

    Args:
        text: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """

@abstractmethod
async def batch_embed(
    self,
    texts: List[str],
) -> List[List[float]]:
    """Embed multiple texts in one call.

    Args:
        texts: The texts to embed.

    Returns:
        One embedding vector per input text, in the same order.
    """

Comment on lines +25 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The abstract methods have already been declared in TransformerBase. Remove them.


class SummarizerBase(TransformerBase, ABC):
"""Summarizer base class."""
Expand Down
51 changes: 51 additions & 0 deletions dbgpt/rag/transformer/graph_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""GraphEmbedder class."""

import logging
from typing import List

from dbgpt.rag.transformer.text2vector import Text2Vector
from dbgpt.storage.graph_store.graph import Graph, GraphElemType

logger = logging.getLogger(__name__)


class GraphEmbedder(Text2Vector):
"""GraphEmbedder class."""

def __init__(self):
    """Initialize the GraphEmbedder."""
    # Delegates all embedder setup to the Text2Vector base class.
    super().__init__()

async def embed(
    self,
    text: str,
) -> List[float]:
    """Embed a single text into a vector.

    Args:
        text: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    # Plain self-dispatch is sufficient: this class does not override
    # _embed, so routing through super() added indirection for no benefit.
    return await self._embed(text)

async def batch_embed(
self,
graphs_list: List[List[Graph]],
) -> List[List[Graph]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use:
async def batch_embed(
self,
graphs_list: List[Graph],
) -> List[Graph]:

"""Embed graphs from graphs in batches"""

for graphs in graphs_list:
for graph in graphs:
for vertex in graph.vertices():
if vertex.get_prop("vertex_type") == GraphElemType.CHUNK.value:
text = vertex.get_prop("content")
vector = await self._embed(text)
vertex.set_prop("embedding", vector)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use _embedding, since it is a prop not visible for the users (it's implicit).

elif vertex.get_prop("vertex_type") == GraphElemType.ENTITY.value:
vector = await self._embed(vertex.vid)
vertex.set_prop("embedding", vector)
Comment on lines +40 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why convert the id to vec? Is the name of the vertext none? (I do not remember the vertext data structure very clearly. Need the reply)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The vid of the entity is the keyword of the entity. We use its vector to search for similar keywords instead of relying on exact keyword matching.

else:
text = ""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it


return graphs_list

def truncate(self):
    """Do nothing by default; this embedder keeps no truncatable state."""

def drop(self):
    """Do nothing by default; this embedder holds no resources to drop."""
42 changes: 41 additions & 1 deletion dbgpt/rag/transformer/text2vector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,50 @@
"""Text2Vector class."""

import logging
from abc import ABC
from http import HTTPStatus
from typing import List

import dashscope

from dbgpt.rag.transformer.base import EmbedderBase

logger = logging.getLogger(__name__)


class Text2Vector(EmbedderBase):
class Text2Vector(EmbedderBase, ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it an abstract class?

"""Text2Vector class."""

def __init__(self):
    """Initialize the Embedder."""
    # NOTE(review): does not call super().__init__() — confirm EmbedderBase
    # requires no initialization of its own.

async def embed(self, text: str) -> List[float]:
    """Return the embedding vector for a single piece of text."""
    vector = await self._embed(text)
    return vector

async def batch_embed(
self,
texts: List[str],
) -> List[List[float]]:
Comment on lines +25 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a parameter called batch_size, and do not forget to update the config variable consistently.
Refer to: _triplet_extraction_batch_size.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return List[] maybe better.

"""Batch embed vectors from texts."""
results = []
for text in texts:
vector = await self._embed(text)
results.extend(vector)
Comment on lines +31 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle the task in batch/parallel (note: asyncio)

And, it is append rather than extend.

return results

async def _embed(self, text: str) -> List[float]:
"""Embed vector from text."""
resp = dashscope.TextEmbedding.call(
model = dashscope.TextEmbedding.Models.text_embedding_v3,
input = text,
dimension = 512)
Comment on lines +38 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better not to fix the embedded model, which is not in line with extensibility.

The solution should be to encapsulate embedding_fn (which is an LLM embedding model, not dashscope's model) instead of calling a pkg. Refer: embedding_fn: Optional[Embeddings] = Field( ... from BuiltinKnowledgeGraphConfig.

embeddings = resp.output['embeddings']
embedding = embeddings[0]['embedding']
return list(embedding)

def truncate(self):
    """Do nothing by default; Text2Vector keeps no truncatable state."""

def drop(self):
    """Do nothing by default; Text2Vector holds no resources to drop."""
109 changes: 101 additions & 8 deletions dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, graph_store: TuGraphStore):
# Create the graph
self.create_graph(self.graph_store.get_config().name)

# vector index create control
self._chunk_vector_index = False
self._entity_vector_index = False

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upsert_vector和upsert_chunk是会被重复调用吗?为什么需要这个变量?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

会重复调用两个upsert,所以在这里用这个变量控制,保证不会多次调用创建索引的语句


async def discover_communities(self, **kwargs) -> List[str]:
"""Run community discovery with leiden."""
mg = self.query(
Expand Down Expand Up @@ -145,6 +149,7 @@ def upsert_entities(self, entities: Iterator[Vertex]) -> None:
"_document_id": "0",
"_chunk_id": "0",
"_community_id": "0",
"_embedding": entity.get_prop("embedding"),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的get_prop为什么还是embedding而不是_embedding?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在构建memory graph的时候向量的key用的是embedding,后面为了避免从tugraph拿出数据的时候数据长度太长,所以给schema中的向量字段名称改为_embedding,这样可以被white_list过滤,防止返回的数据量太大,chunk同理

}
for entity in entities
]
Expand All @@ -153,8 +158,17 @@ def upsert_entities(self, entities: Iterator[Vertex]) -> None:
f'"{GraphElemType.ENTITY.value}", '
f"[{self._convert_dict_to_str(entity_list)}])"
)
create_vector_index_query = (
f"CALL db.addVertexVectorIndex("
f'"{GraphElemType.ENTITY.value}", "_embedding", '
"{dimension: 512})"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

向量的dimension是否需要一个配置项?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
self.graph_store.conn.run(query=entity_query)

if not self._entity_vector_index :
self.graph_store.conn.run(query=create_vector_index_query)
self._entity_vector_index = True

def upsert_edge(
self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
) -> None:
Expand Down Expand Up @@ -189,6 +203,7 @@ def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
"_embedding": chunk.get_prop("embedding"),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的get_prop为什么还是embedding而不是_embedding?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和对entity的处理相似

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(chunk, ParagraphChunk), it should also has a prop _embedding.

}
for chunk in chunks
]
Expand All @@ -198,7 +213,16 @@ def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None
f'"{GraphElemType.CHUNK.value}", '
f"[{self._convert_dict_to_str(chunk_list)}])"
)
create_vector_index_query = (
f"CALL db.addVertexVectorIndex("
f'"{GraphElemType.CHUNK.value}", "_embedding", '
"{dimension: 512})"
)
Comment on lines +219 to +220
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No hardcode here please. The dimension should be got from the configs.

self.graph_store.conn.run(query=chunk_query)

if not self._chunk_vector_index :
self.graph_store.conn.run(query=create_vector_index_query)
self._chunk_vector_index = True
Comment on lines +223 to +225
Copy link
Contributor

@Appointat Appointat Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it


def upsert_documents(
self, documents: Iterator[Union[Vertex, ParagraphChunk]]
Expand Down Expand Up @@ -404,6 +428,7 @@ def _format_graph_property_schema(
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("content", "STRING", True, True),
_format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False),
]
self.create_graph_label(
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
Expand All @@ -415,6 +440,7 @@ def _format_graph_property_schema(
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("description", "STRING", True, True),
_format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False),
]
self.create_graph_label(
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
Expand Down Expand Up @@ -531,7 +557,7 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool:

def explore(
self,
subs: List[str],
subs: Union[List[str], List[List[float]]],
direct: Direction = Direction.BOTH,
depth: int = 3,
fan: Optional[int] = None,
Expand Down Expand Up @@ -560,10 +586,28 @@ def explore(
rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
else:
rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some code commendations.

if all(isinstance(item, str) for item in subs):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里判断subs是List[str]还是List[List[float]]是否需要遍历整个list才能确定?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照之前最开始不遍历的方法,因为都是list类型,他就分辨不出来,所以这里得进去分辨

header = f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
else:
final_list = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
final_list = []
similar_entities: List[Some_Type] = []

for sub in subs:
vector = str(sub);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove ;

similarity_search = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
similarity_search = (
similarity_retrieval_query = (

f"CALL db.vertexVectorKnnSearch("
f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
"{top_k:2, hnsw_ef_search:10})"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k是否需要一个配置项?这里直接指定的是2

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hnsw_ef_search这个参数指定的是?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k太大了我觉得可能会影响效果,毕竟按照原来的设计,一个关键字对应一个header,这里相当于是一个向量对应两个header了;ef_search是hnsw搜索时需要指定的参数,构建时也同样需要指定参数,但是我们这里使用了默认的就没有写出来

"YIELD node RETURN node.id AS id;"
)
result_list = self.graph_store.conn.run(query=similarity_search)
final_list.extend(result_list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

append (✅) extend(x)

id_list = [(record["id"]) for record in final_list]
header = f"WHERE n.id IN {id_list} "

query = (
f"MATCH p=(n:{GraphElemType.ENTITY.value})"
f"{rel}(m:{GraphElemType.ENTITY.value}) "
f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
f"{header}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace the name to conditional_statement

f"RETURN p {limit_string}"
)
return self.query(query=query, white_list=["description"])
Expand All @@ -583,18 +627,52 @@ def explore(
graph = MemoryGraph()

# Check if the entities exist in the graph

if all(isinstance(item, str) for item in subs):
header = f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
else:
final_list = []
for sub in subs:
vector = str(sub);
similarity_search = (
f"CALL db.vertexVectorKnnSearch("
f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
"{top_k:2, hnsw_ef_search:10})"
"YIELD node RETURN node.id AS id"
)
result_list = self.graph_store.conn.run(query=similarity_search)
final_list.extend(result_list)
id_list = [(record["id"]) for record in final_list]
header = f"WHERE n.id IN {id_list} "
check_entity_query = (
f"MATCH (n:{GraphElemType.ENTITY.value}) "
f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
f"{header}"
"RETURN n"
)
if self.query(check_entity_query):
# Query the leaf chunks in the chain from documents to chunks
if all(isinstance(item, str) for item in subs):
header = f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
else:
final_list = []
for sub in subs:
vector = str(sub);
similarity_search = (
f"CALL db.vertexVectorKnnSearch("
f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
"{top_k:2, hnsw_ef_search:10})"
"YIELD node RETURN node.name AS name"
)
result_list = self.graph_store.conn.run(query=similarity_search)
final_list.extend(result_list)
name_list = [(record["name"]) for record in final_list]
header = f"WHERE n.name IN {name_list} "

leaf_chunk_query = (
f"MATCH p=(n:{GraphElemType.CHUNK.value})-"
f"[r:{GraphElemType.INCLUDE.value}]->"
f"(m:{GraphElemType.ENTITY.value})"
f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
f"{header} "
f"RETURN n"
)
graph_of_leaf_chunks = self.query(
Expand Down Expand Up @@ -628,10 +706,25 @@ def explore(
)
)
else:
_subs_condition = " OR ".join(
[f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs]
)

if all(isinstance(item, str) for item in subs):
_subs_condition = " OR ".join(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace the name to conditional_statement

[f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs]
)
else:
final_list = []
for sub in subs:
vector = str(sub);
similarity_search = (
f"CALL db.vertexVectorKnnSearch("
f"'{GraphElemType.CHUNK.value}','_embedding', {vector}, "
"{top_k:2, hnsw_ef_search:10})"
"YIELD node RETURN node.name AS name"
)
result_list = self.graph_store.conn.run(query=similarity_search)
final_list.extend(result_list)
name_list = [(record["name"]) for record in final_list]
_subs_condition = f"n.name IN {name_list} "

# Query the chain from documents to chunks,
# document -> chunk -> chunk -> chunk -> ... -> chunk
chain_query = (
Expand Down
Loading
Loading