diff --git a/dbgpt/rag/transformer/base.py b/dbgpt/rag/transformer/base.py
index a71c2da14..35b3e022d 100644
--- a/dbgpt/rag/transformer/base.py
+++ b/dbgpt/rag/transformer/base.py
@@ -22,6 +22,17 @@ def drop(self):
 class EmbedderBase(TransformerBase, ABC):
     """Embedder base class."""
 
+    @abstractmethod
+    async def embed(self, text: str) -> List[float]:
+        """Embed vector from text."""
+
+    @abstractmethod
+    async def batch_embed(
+        self,
+        texts: List[str],
+    ) -> List[List[float]]:
+        """Batch embed vectors from texts."""
+
 
 class SummarizerBase(TransformerBase, ABC):
     """Summarizer base class."""
diff --git a/dbgpt/rag/transformer/graph_embedder.py b/dbgpt/rag/transformer/graph_embedder.py
new file mode 100644
index 000000000..6865832a6
--- /dev/null
+++ b/dbgpt/rag/transformer/graph_embedder.py
@@ -0,0 +1,51 @@
+"""GraphEmbedder class."""
+
+import logging
+from typing import List
+
+from dbgpt.rag.transformer.text2vector import Text2Vector
+from dbgpt.storage.graph_store.graph import Graph, GraphElemType
+
+logger = logging.getLogger(__name__)
+
+
+class GraphEmbedder(Text2Vector):
+    """GraphEmbedder class."""
+
+    def __init__(self):
+        """Initialize the GraphEmbedder."""
+        super().__init__()
+
+    async def embed(
+        self,
+        text: str,
+    ) -> List[float]:
+        """Embed vector from text."""
+        return await super()._embed(text)
+
+    async def batch_embed(
+        self,
+        graphs_list: List[List[Graph]],
+    ) -> List[List[Graph]]:
+        """Embed the chunk and entity vertices of each graph, in place."""
+        for graphs in graphs_list:
+            for graph in graphs:
+                for vertex in graph.vertices():
+                    vertex_type = vertex.get_prop("vertex_type")
+                    if vertex_type == GraphElemType.CHUNK.value:
+                        # Chunk vertices are embedded from their content
+                        vector = await self._embed(vertex.get_prop("content"))
+                        vertex.set_prop("embedding", vector)
+                    elif vertex_type == GraphElemType.ENTITY.value:
+                        # Entity vertices are embedded from their id
+                        vector = await self._embed(vertex.vid)
+                        vertex.set_prop("embedding", vector)
+        return graphs_list
+
+    def truncate(self):
+        """Do nothing by default."""
+
+    def drop(self):
+        """Do nothing by default."""
diff --git a/dbgpt/rag/transformer/text2vector.py b/dbgpt/rag/transformer/text2vector.py
index d7257001c..a81336509 100644
--- a/dbgpt/rag/transformer/text2vector.py
+++ b/dbgpt/rag/transformer/text2vector.py
@@ -1,10 +1,50 @@
 """Text2Vector class."""
+
 import logging
+from abc import ABC
+from http import HTTPStatus
+from typing import List
+
+import dashscope
 
 from dbgpt.rag.transformer.base import EmbedderBase
 
 logger = logging.getLogger(__name__)
 
 
-class Text2Vector(EmbedderBase):
+class Text2Vector(EmbedderBase, ABC):
     """Text2Vector class."""
+
+    def __init__(self):
+        """Initialize the embedder."""
+
+    async def embed(self, text: str) -> List[float]:
+        """Embed vector from text."""
+        return await self._embed(text)
+
+    async def batch_embed(
+        self,
+        texts: List[str],
+    ) -> List[List[float]]:
+        """Batch embed vectors from texts."""
+        results = []
+        for text in texts:
+            vector = await self._embed(text)
+            # Append keeps one vector per text; extend would flatten them
+            results.append(vector)
+        return results
+
+    async def _embed(self, text: str) -> List[float]:
+        """Embed vector from text with the DashScope embedding API."""
+        resp = dashscope.TextEmbedding.call(
+            model=dashscope.TextEmbedding.Models.text_embedding_v3,
+            input=text,
+            dimension=512,
+        )
+        if resp.status_code != HTTPStatus.OK:
+            raise RuntimeError(f"Text embedding failed: {resp}")
+        embedding = resp.output["embeddings"][0]["embedding"]
+        return list(embedding)
+
+    def truncate(self):
+        """Do nothing by default."""
+
+    def drop(self):
+        """Do nothing by default."""
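Note: a minimal sketch of how the new embedder API is meant to be called (illustrative, not part of the patch; assumes a configured DashScope API key):

    import asyncio

    from dbgpt.rag.transformer.graph_embedder import GraphEmbedder

    async def demo():
        embedder = GraphEmbedder()
        # embed() returns one vector; the dimension is fixed to 512
        # in Text2Vector._embed above.
        vector = await embedder.embed("TuGraph is a graph database.")
        assert len(vector) == 512
        # batch_embed() walks nested graphs and writes an "embedding" prop
        # onto chunk vertices (from content) and entity vertices (from vid):
        # graphs_list = await embedder.batch_embed(graphs_list)

    asyncio.run(demo())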
diff --git a/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py b/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
index 3a2d96ae5..611f6708b 100644
--- a/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
+++ b/dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
@@ -42,6 +42,10 @@ def __init__(self, graph_store: TuGraphStore):
         # Create the graph
         self.create_graph(self.graph_store.get_config().name)
 
+        # Flags that ensure each vector index is created only once
+        self._chunk_vector_index = False
+        self._entity_vector_index = False
+
     async def discover_communities(self, **kwargs) -> List[str]:
         """Run community discovery with leiden."""
         mg = self.query(
@@ -145,6 +149,7 @@ def upsert_entities(self, entities: Iterator[Vertex]) -> None:
                 "_document_id": "0",
                 "_chunk_id": "0",
                 "_community_id": "0",
+                "_embedding": entity.get_prop("embedding"),
             }
             for entity in entities
         ]
@@ -153,8 +158,17 @@ def upsert_entities(self, entities: Iterator[Vertex]) -> None:
             f'"{GraphElemType.ENTITY.value}", '
             f"[{self._convert_dict_to_str(entity_list)}])"
         )
+        create_vector_index_query = (
+            f"CALL db.addVertexVectorIndex("
+            f'"{GraphElemType.ENTITY.value}", "_embedding", '
+            "{dimension: 512})"
+        )
         self.graph_store.conn.run(query=entity_query)
+        if not self._entity_vector_index:
+            self.graph_store.conn.run(query=create_vector_index_query)
+            self._entity_vector_index = True
 
     def upsert_edge(
         self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
     ) -> None:
@@ -189,6 +203,7 @@ def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None
                 "id": self._escape_quotes(chunk.vid),
                 "name": self._escape_quotes(chunk.name),
                 "content": self._escape_quotes(chunk.get_prop("content")),
+                "_embedding": chunk.get_prop("embedding"),
             }
             for chunk in chunks
         ]
@@ -198,7 +213,16 @@ def upsert_chunks(self, chunks: Iterator[Union[Vertex, ParagraphChunk]]) -> None
             f'"{GraphElemType.CHUNK.value}", '
             f"[{self._convert_dict_to_str(chunk_list)}])"
         )
+        create_vector_index_query = (
+            f"CALL db.addVertexVectorIndex("
+            f'"{GraphElemType.CHUNK.value}", "_embedding", '
+            "{dimension: 512})"
+        )
         self.graph_store.conn.run(query=chunk_query)
+
+        if not self._chunk_vector_index:
+            self.graph_store.conn.run(query=create_vector_index_query)
+            self._chunk_vector_index = True
 
     def upsert_documents(
         self, documents: Iterator[Union[Vertex, ParagraphChunk]]
@@ -404,6 +428,7 @@ def _format_graph_property_schema(
             _format_graph_property_schema("name", "STRING", False),
             _format_graph_property_schema("_community_id", "STRING", True, True),
             _format_graph_property_schema("content", "STRING", True, True),
+            _format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False),
         ]
         self.create_graph_label(
             graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
@@ -415,6 +440,7 @@ def _format_graph_property_schema(
             _format_graph_property_schema("name", "STRING", False),
             _format_graph_property_schema("_community_id", "STRING", True, True),
             _format_graph_property_schema("description", "STRING", True, True),
+            _format_graph_property_schema("_embedding", "FLOAT_VECTOR", True, False),
         ]
         self.create_graph_label(
             graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
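Note: rendered, the one-time index statements issued by upsert_entities and upsert_chunks above look roughly as follows (label values assumed to be the GraphElemType strings "entity" and "chunk"):

    CALL db.addVertexVectorIndex("entity", "_embedding", {dimension: 512})
    CALL db.addVertexVectorIndex("chunk", "_embedding", {dimension: 512})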
@@ -531,7 +557,7 @@ def check_label(self, graph_elem_type: GraphElemType) -> bool:
     def explore(
         self,
-        subs: List[str],
+        subs: Union[List[str], List[List[float]]],
         direct: Direction = Direction.BOTH,
         depth: int = 3,
         fan: Optional[int] = None,
@@ -560,10 +586,28 @@ def explore(
             rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
         else:
             rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
+
+        if all(isinstance(item, str) for item in subs):
+            # Keyword mode: match entity ids directly
+            header = f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
+        else:
+            # Vector mode: resolve each query vector to entity ids via KNN search
+            final_list = []
+            for sub in subs:
+                vector = str(sub)
+                similarity_search = (
+                    f"CALL db.vertexVectorKnnSearch("
+                    f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
+                    "{top_k:2, hnsw_ef_search:10}) "
+                    "YIELD node RETURN node.id AS id"
+                )
+                result_list = self.graph_store.conn.run(query=similarity_search)
+                final_list.extend(result_list)
+            id_list = [record["id"] for record in final_list]
+            header = f"WHERE n.id IN {id_list} "
+
         query = (
             f"MATCH p=(n:{GraphElemType.ENTITY.value})"
             f"{rel}(m:{GraphElemType.ENTITY.value}) "
-            f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
+            f"{header}"
            f"RETURN p {limit_string}"
         )
         return self.query(query=query, white_list=["description"])
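Note: when subs carries query vectors instead of keywords, each vector is first resolved to entity ids with a KNN lookup before the path expansion runs. The generated procedure call renders roughly as (vector elided):

    CALL db.vertexVectorKnnSearch('entity', '_embedding', [0.12, -0.03, ...],
        {top_k:2, hnsw_ef_search:10}) YIELD node RETURN node.id AS id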
@@ -583,18 +627,52 @@
         graph = MemoryGraph()
 
         # Check if the entities exist in the graph
+        if all(isinstance(item, str) for item in subs):
+            header = f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
+        else:
+            final_list = []
+            for sub in subs:
+                vector = str(sub)
+                similarity_search = (
+                    f"CALL db.vertexVectorKnnSearch("
+                    f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
+                    "{top_k:2, hnsw_ef_search:10}) "
+                    "YIELD node RETURN node.id AS id"
+                )
+                result_list = self.graph_store.conn.run(query=similarity_search)
+                final_list.extend(result_list)
+            id_list = [record["id"] for record in final_list]
+            header = f"WHERE n.id IN {id_list} "
         check_entity_query = (
             f"MATCH (n:{GraphElemType.ENTITY.value}) "
-            f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
+            f"{header}"
             "RETURN n"
         )
 
         if self.query(check_entity_query):
             # Query the leaf chunks in the chain from documents to chunks
+            if all(isinstance(item, str) for item in subs):
+                header = f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
+            else:
+                final_list = []
+                for sub in subs:
+                    vector = str(sub)
+                    similarity_search = (
+                        f"CALL db.vertexVectorKnnSearch("
+                        f"'{GraphElemType.ENTITY.value}','_embedding', {vector}, "
+                        "{top_k:2, hnsw_ef_search:10}) "
+                        "YIELD node RETURN node.name AS name"
+                    )
+                    result_list = self.graph_store.conn.run(query=similarity_search)
+                    final_list.extend(result_list)
+                name_list = [record["name"] for record in final_list]
+                # Filter on the entity side (m) of the pattern below
+                header = f"WHERE m.name IN {name_list} "
+
             leaf_chunk_query = (
                 f"MATCH p=(n:{GraphElemType.CHUNK.value})-"
                 f"[r:{GraphElemType.INCLUDE.value}]->"
                 f"(m:{GraphElemType.ENTITY.value})"
-                f"WHERE m.name IN {[self._escape_quotes(sub) for sub in subs]} "
+                f"{header} "
                 f"RETURN n"
             )
             graph_of_leaf_chunks = self.query(
@@ -628,10 +706,25 @@
                 )
             )
         else:
-            _subs_condition = " OR ".join(
-                [f"m.content CONTAINS '{self._escape_quotes(sub)}'" for sub in subs]
-            )
-
+            if all(isinstance(item, str) for item in subs):
+                _subs_condition = " OR ".join(
+                    [
+                        f"m.content CONTAINS '{self._escape_quotes(sub)}'"
+                        for sub in subs
+                    ]
+                )
+            else:
+                final_list = []
+                for sub in subs:
+                    vector = str(sub)
+                    similarity_search = (
+                        f"CALL db.vertexVectorKnnSearch("
+                        f"'{GraphElemType.CHUNK.value}','_embedding', {vector}, "
+                        "{top_k:2, hnsw_ef_search:10}) "
+                        "YIELD node RETURN node.name AS name"
+                    )
+                    result_list = self.graph_store.conn.run(query=similarity_search)
+                    final_list.extend(result_list)
+                name_list = [record["name"] for record in final_list]
+                # Match the chunk side (m) of the document-to-chunk chain
+                _subs_condition = f"m.name IN {name_list} "
+
             # Query the chain from documents to chunks,
             # document -> chunk -> chunk -> chunk -> ... -> chunk
             chain_query = (
diff --git a/dbgpt/storage/knowledge_graph/community_summary.py b/dbgpt/storage/knowledge_graph/community_summary.py
index 806f9df54..f8e40a1d4 100644
--- a/dbgpt/storage/knowledge_graph/community_summary.py
+++ b/dbgpt/storage/knowledge_graph/community_summary.py
@@ -8,6 +8,7 @@
 from dbgpt._private.pydantic import ConfigDict, Field
 from dbgpt.core import Chunk
 from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
+from dbgpt.rag.transformer.graph_embedder import GraphEmbedder
 from dbgpt.rag.transformer.graph_extractor import GraphExtractor
 from dbgpt.storage.knowledge_graph.base import ParagraphChunk
 from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
@@ -78,6 +79,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
         default=20,
         description="Batch size of parallel community building process",
     )
+    similar_search_enabled: bool = Field(
+        default=False,
+        description="Enable vector similarity search",
+    )
 
 
 class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
@@ -139,6 +144,11 @@ def __init__(self, config: CommunitySummaryKnowledgeGraphConfig):
                 config.community_summary_batch_size,
             )
         )
+        self._similar_search_enabled = (
+            os.environ["SIMILAR_SEARCH_ENABLED"].lower() == "true"
+            if "SIMILAR_SEARCH_ENABLED" in os.environ
+            else config.similar_search_enabled
+        )
 
         def extractor_configure(name: str, cfg: VectorStoreConfig):
             cfg.name = name
@@ -160,6 +170,8 @@ def extractor_configure(name: str, cfg: VectorStoreConfig):
             ),
         )
 
+        self._graph_embedder = GraphEmbedder()
+
         def community_store_configure(name: str, cfg: VectorStoreConfig):
             cfg.name = name
             cfg.embedding_fn = config.embedding_fn
@@ -244,6 +256,8 @@ async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
         if not graphs_list:
             raise ValueError("No graphs extracted from the chunks")
 
+        # Attach embeddings to the extracted graphs before upserting them
+        graphs_list = await self._graph_embedder.batch_embed(graphs_list)
+
         # Upsert the graphs into the graph store
         for idx, graphs in enumerate(graphs_list):
             for graph in graphs:
@@ -333,7 +347,16 @@ async def asimilar_search_with_scores(
         ]
         context = "\n".join(summaries) if summaries else ""
 
-        keywords: List[str] = await self._keyword_extractor.extract(text)
+        # With similarity search enabled, embed the query text instead of
+        # extracting keywords from it
+        if self._similar_search_enabled:
+            keywords: List[List[float]] = [await self._graph_embedder.embed(text)]
+        else:
+            keywords: List[str] = await self._keyword_extractor.extract(text)
+
         subgraph = None
         subgraph_for_doc = None
@@ -348,8 +371,13 @@ async def asimilar_search_with_scores(
 
             if document_graph_enabled:
                 keywords_for_document_graph = keywords
-                for vertex in subgraph.vertices():
-                    keywords_for_document_graph.append(vertex.name)
+                if self._similar_search_enabled:
+                    # Embed vertex names so they match the vector-typed keywords
+                    for vertex in subgraph.vertices():
+                        vector = await self._graph_embedder.embed(vertex.name)
+                        keywords_for_document_graph.append(vector)
+                else:
+                    for vertex in subgraph.vertices():
+                        keywords_for_document_graph.append(vertex.name)
 
                 subgraph_for_doc = self._graph_store_apdater.explore(
                     subs=keywords_for_document_graph,
@@ -363,6 +391,7 @@ async def asimilar_search_with_scores(
                     limit=self._knowledge_graph_chunk_search_top_size,
                     search_scope="document_graph",
                 )
+
         knowledge_graph_str = subgraph.format() if subgraph else ""
         knowledge_graph_for_doc_str = (
             subgraph_for_doc.format() if subgraph_for_doc else ""
         )
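Note: a sketch of the two ways to enable the new retrieval path (names taken from this patch; other required config fields omitted):

    # The environment variable takes precedence over the config field:
    #   export SIMILAR_SEARCH_ENABLED=true
    config = CommunitySummaryKnowledgeGraphConfig(
        name="my_kg",
        similar_search_enabled=True,
    )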
extractor") self._graph_extractor.truncate() + logger.info("Truncate graph embedder") + self._garph_embedder.truncate() return [self._config.name] def delete_vector_name(self, index_name: str): @@ -403,6 +434,9 @@ def delete_vector_name(self, index_name: str): logger.info("Drop triplet extractor") self._graph_extractor.drop() + logger.info("Drop graph embedder") + self._garph_embedder.drop() + HYBRID_SEARCH_PT = """ =====