-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support embedding similarity search for GraphRAG #2200
base: main
Are you sure you want to change the base?
Changes from all commits
897c3da
3ea27ca
7fdcd6e
de3b351
b8ba884
24443c2
fdc0068
a2a759f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""GraphEmbedder class.""" | ||
|
||
import logging | ||
from typing import List | ||
|
||
from dbgpt.rag.transformer.text2vector import Text2Vector | ||
from dbgpt.storage.graph_store.graph import Graph, GraphElemType | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GraphEmbedder(Text2Vector): | ||
"""GraphEmbedder class.""" | ||
|
||
def __init__(self): | ||
"""Initialize the GraphEmbedder""" | ||
super().__init__() | ||
|
||
async def embed( | ||
self, | ||
text: str, | ||
) -> List[float]: | ||
"""Embed""" | ||
return await super()._embed(text) | ||
|
||
async def batch_embed( | ||
self, | ||
graphs_list: List[List[Graph]], | ||
) -> List[List[Graph]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use: |
||
"""Embed graphs from graphs in batches""" | ||
|
||
for graphs in graphs_list: | ||
for graph in graphs: | ||
for vertex in graph.vertices(): | ||
if vertex.get_prop("vertex_type") == GraphElemType.CHUNK.value: | ||
text = vertex.get_prop("content") | ||
vector = await self._embed(text) | ||
vertex.set_prop("embedding", vector) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use _embedding, since it is a prop not visible for the users (it's implicit). |
||
elif vertex.get_prop("vertex_type") == GraphElemType.ENTITY.value: | ||
vector = await self._embed(vertex.vid) | ||
vertex.set_prop("embedding", vector) | ||
Comment on lines
+40
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why convert the id to vec? Is the name of the vertex none? (I do not remember the vertex data structure very clearly. Need the reply) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The vid of the entity is the keywords of the entity. We use its vector to search for similar keywords instead of matching the keywords directly. |
||
else: | ||
text = "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove it |
||
|
||
return graphs_list | ||
|
||
def truncate(self): | ||
"""""" | ||
|
||
def drop(self): | ||
"""""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,50 @@ | ||
"""Text2Vector class.""" | ||
|
||
import logging | ||
from abc import ABC | ||
from http import HTTPStatus | ||
from typing import List | ||
|
||
import dashscope | ||
|
||
from dbgpt.rag.transformer.base import EmbedderBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Text2Vector(EmbedderBase): | ||
class Text2Vector(EmbedderBase, ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it an abstract class? |
||
"""Text2Vector class.""" | ||
|
||
def __init__(self): | ||
"""Initialize the Embedder""" | ||
|
||
async def embed(self, text: str) -> List[float]: | ||
"""Embed vector from text.""" | ||
return await self._embed(text) | ||
|
||
async def batch_embed( | ||
self, | ||
texts: List[str], | ||
) -> List[List[float]]: | ||
Comment on lines
+25
to
+28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a parameter called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returning List[] may be better. |
||
"""Batch embed vectors from texts.""" | ||
results = [] | ||
for text in texts: | ||
vector = await self._embed(text) | ||
results.extend(vector) | ||
Comment on lines
+31
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle the task in batch/parallel (note: And, it is |
||
return results | ||
|
||
async def _embed(self, text: str) -> List[float]: | ||
"""Embed vector from text.""" | ||
resp = dashscope.TextEmbedding.call( | ||
model = dashscope.TextEmbedding.Models.text_embedding_v3, | ||
input = text, | ||
dimension = 512) | ||
Comment on lines
+38
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is better not to fix the embedded model, which is not in line with extensibility. The solution should be to encapsulate |
||
embeddings = resp.output['embeddings'] | ||
embedding = embeddings[0]['embedding'] | ||
return list(embedding) | ||
|
||
def truncate(self): | ||
"""Do nothing by default.""" | ||
|
||
def drop(self): | ||
"""Do nothing by default.""" |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -42,6 +42,10 @@ def __init__(self, graph_store: TuGraphStore): | |||||
# Create the graph | ||||||
self.create_graph(self.graph_store.get_config().name) | ||||||
|
||||||
# vector index create control | ||||||
self._chunk_vector_index = False | ||||||
self._entity_vector_index = False | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. upsert_vector和upsert_chunk是会被重复调用吗?为什么需要这个变量? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 会重复调用两个upsert,所以在这里用这个变量控制,保证不会多次调用创建索引的语句 |
||||||
|
||||||
async def discover_communities(self, **kwargs) -> List[str]: | ||||||
"""Run community discovery with leiden.""" | ||||||
mg = self.query( | ||||||
|
@@ -145,6 +149,7 @@ def upsert_entities(self, entities: Iterator[Vertex]) -> None: | |||||
"_document_id": "0", | ||||||
"_chunk_id": "0", | ||||||
"_community_id": "0", | ||||||
"_embedding": entity.get_prop("embedding"), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的get_prop为什么还是embedding而不是_embedding? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在构建memory graph的时候向量的key用的是embedding,后面为了避免从tugraph拿出数据的时候数据长度太长,所以给schema中的向量字段名称改为_embedding,这样可以被white_list过滤,防止返回的数据量太大,chunk同理 |
||||||
} | ||||||
for entity in entities | ||||||
] | ||||||
|
@@ -153,8 +158,17 @@ def upsert_entities(self, entities: Iterator[Vertex]) -> None: | |||||
f'"{GraphElemType.ENTITY.value}", ' | ||||||
f"[{self._convert_dict_to_str(entity_list)}])" | ||||||
) | ||||||
create_vector_index_query = ( | ||||||
f"CALL db.addVertexVectorIndex(" | ||||||
f'"{GraphElemType.ENTITY.value}", "_embedding", ' | ||||||
"{dimension: 512})" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 向量的dimension是否需要一个配置项? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
) | ||||||
self.graph_store.conn.run(query=entity_query) | ||||||
|
||||||
if not self._entity_vector_index : | ||||||
self.graph_store.conn.run(query=create_vector_index_query) | ||||||
self._entity_vector_index = True | ||||||
|
||||||
def upsert_edge( | ||||||
self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str | ||||||
) -> None: | ||||||
|
@@ -189,6 +203,7 @@ def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None | |||||
"id": self._escape_quotes(chunk.vid), | ||||||
"name": self._escape_quotes(chunk.name), | ||||||
"content": self._escape_quotes(chunk.get_prop("content")), | ||||||
"_embedding": chunk.get_prop("embedding"), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的get_prop为什么还是embedding而不是_embedding? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 和对entity的处理相似 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if isinstance(chunk, ParagraphChunk), it should also has a prop |
||||||
} | ||||||
for chunk in chunks | ||||||
] | ||||||
|
@@ -198,7 +213,16 @@ def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None | |||||
f'"{GraphElemType.CHUNK.value}", ' | ||||||
f"[{self._convert_dict_to_str(chunk_list)}])" | ||||||
) | ||||||
create_vector_index_query = ( | ||||||
f"CALL db.addVertexVectorIndex(" | ||||||
f'"{GraphElemType.CHUNK.value}", "_embedding", ' | ||||||
"{dimension: 512})" | ||||||
) | ||||||
Comment on lines
+219
to
+220
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No hardcode here please. The dimension should be got from the configs. |
||||||
self.graph_store.conn.run(query=chunk_query) | ||||||
|
||||||
if not self._chunk_vector_index : | ||||||
self.graph_store.conn.run(query=create_vector_index_query) | ||||||
self._chunk_vector_index = True | ||||||
Comment on lines
+223
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove it |
||||||
|
||||||
def upsert_documents( | ||||||
self, documents: Iterator[Union[Vertex, ParagraphChunk]] | ||||||
|
@@ -404,6 +428,7 @@ def _format_graph_property_schema( | |||||
_format_graph_property_schema("name", "STRING", False), | ||||||
_format_graph_property_schema("_community_id", "STRING", True, True), | ||||||
_format_graph_property_schema("content", "STRING", True, True), | ||||||
_format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False), | ||||||
] | ||||||
self.create_graph_label( | ||||||
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties | ||||||
|
@@ -415,6 +440,7 @@ def _format_graph_property_schema( | |||||
_format_graph_property_schema("name", "STRING", False), | ||||||
_format_graph_property_schema("_community_id", "STRING", True, True), | ||||||
_format_graph_property_schema("description", "STRING", True, True), | ||||||
_format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False), | ||||||
] | ||||||
self.create_graph_label( | ||||||
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties | ||||||
|
@@ -531,7 +557,7 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool: | |||||
|
||||||
def explore( | ||||||
self, | ||||||
subs: List[str], | ||||||
subs: Union[List[str], List[List[float]]], | ||||||
direct: Direction = Direction.BOTH, | ||||||
depth: int = 3, | ||||||
fan: Optional[int] = None, | ||||||
|
@@ -560,10 +586,28 @@ def explore( | |||||
rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-" | ||||||
else: | ||||||
rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-" | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add some code comments. |
||||||
if all(isinstance(item, str) for item in subs): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里判断subs是List[str]还是List[List[float]]是否需要遍历整个list才能确定? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 按照之前最开始不遍历的方法,因为都是list类型,他就分辨不出来,所以这里得进去分辨 |
||||||
header = f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} " | ||||||
else: | ||||||
final_list = [] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
for sub in subs: | ||||||
vector = str(sub); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||||||
similarity_search = ( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
f"CALL db.vertexVectorKnnSearch(" | ||||||
f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, " | ||||||
"{top_k:2, hnsw_ef_search:10})" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. top_k是否需要一个配置项?这里直接指定的是2 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hnsw_ef_search这个参数指定的是? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. top_k太大了我觉得可能会影响效果,毕竟按照原来的设计,一个关键字对应一个header,这里相当于是一个向量对应两个header了;ef_search是hnsw搜索时需要指定的参数,构建时也同样需要指定参数,但是我们这里使用了默认的就没有写出来 |
||||||
"YIELD node RETURN node.id AS id;" | ||||||
) | ||||||
result_list = self.graph_store.conn.run(query=similarity_search) | ||||||
final_list.extend(result_list) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. append (✅) extend(x) |
||||||
id_list = [(record["id"]) for record in final_list] | ||||||
header = f"WHERE n.id IN {id_list} " | ||||||
|
||||||
query = ( | ||||||
f"MATCH p=(n:{GraphElemType.ENTITY.value})" | ||||||
f"{rel}(m:{GraphElemType.ENTITY.value}) " | ||||||
f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} " | ||||||
f"{header}" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace the name to |
||||||
f"RETURN p {limit_string}" | ||||||
) | ||||||
return self.query(query=query, white_list=["description"]) | ||||||
|
@@ -583,18 +627,52 @@ def explore( | |||||
graph = MemoryGraph() | ||||||
|
||||||
# Check if the entities exist in the graph | ||||||
|
||||||
if all(isinstance(item, str) for item in subs): | ||||||
header = f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} " | ||||||
else: | ||||||
final_list = [] | ||||||
for sub in subs: | ||||||
vector = str(sub); | ||||||
similarity_search = ( | ||||||
f"CALL db.vertexVectorKnnSearch(" | ||||||
f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, " | ||||||
"{top_k:2, hnsw_ef_search:10})" | ||||||
"YIELD node RETURN node.id AS id" | ||||||
) | ||||||
result_list = self.graph_store.conn.run(query=similarity_search) | ||||||
final_list.extend(result_list) | ||||||
id_list = [(record["id"]) for record in final_list] | ||||||
header = f"WHERE n.id IN {id_list} " | ||||||
check_entity_query = ( | ||||||
f"MATCH (n:{GraphElemType.ENTITY.value}) " | ||||||
f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} " | ||||||
f"{header}" | ||||||
"RETURN n" | ||||||
) | ||||||
if self.query(check_entity_query): | ||||||
# Query the leaf chunks in the chain from documents to chunks | ||||||
if all(isinstance(item, str) for item in subs): | ||||||
header = f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} " | ||||||
else: | ||||||
final_list = [] | ||||||
for sub in subs: | ||||||
vector = str(sub); | ||||||
similarity_search = ( | ||||||
f"CALL db.vertexVectorKnnSearch(" | ||||||
f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, " | ||||||
"{top_k:2, hnsw_ef_search:10})" | ||||||
"YIELD node RETURN node.name AS name" | ||||||
) | ||||||
result_list = self.graph_store.conn.run(query=similarity_search) | ||||||
final_list.extend(result_list) | ||||||
name_list = [(record["name"]) for record in final_list] | ||||||
header = f"WHERE n.name IN {name_list} " | ||||||
|
||||||
leaf_chunk_query = ( | ||||||
f"MATCH p=(n:{GraphElemType.CHUNK.value})-" | ||||||
f"[r:{GraphElemType.INCLUDE.value}]->" | ||||||
f"(m:{GraphElemType.ENTITY.value})" | ||||||
f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} " | ||||||
f"{header} " | ||||||
f"RETURN n" | ||||||
) | ||||||
graph_of_leaf_chunks = self.query( | ||||||
|
@@ -628,10 +706,25 @@ def explore( | |||||
) | ||||||
) | ||||||
else: | ||||||
_subs_condition = " OR ".join( | ||||||
[f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs] | ||||||
) | ||||||
|
||||||
if all(isinstance(item, str) for item in subs): | ||||||
_subs_condition = " OR ".join( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. replace the name to |
||||||
[f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs] | ||||||
) | ||||||
else: | ||||||
final_list = [] | ||||||
for sub in subs: | ||||||
vector = str(sub); | ||||||
similarity_search = ( | ||||||
f"CALL db.vertexVectorKnnSearch(" | ||||||
f"'{GraphElemType.CHUNK.value}','_embedding', {vector}, " | ||||||
"{top_k:2, hnsw_ef_search:10})" | ||||||
"YIELD node RETURN node.name AS name" | ||||||
) | ||||||
result_list = self.graph_store.conn.run(query=similarity_search) | ||||||
final_list.extend(result_list) | ||||||
name_list = [(record["name"]) for record in final_list] | ||||||
_subs_condition = f"n.name IN {name_list} " | ||||||
|
||||||
# Query the chain from documents to chunks, | ||||||
# document -> chunk -> chunk -> chunk -> ... -> chunk | ||||||
chain_query = ( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The abstract methods have been declared in
TransformerBase
. Remove them.