Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Jant1L committed Dec 28, 2024
1 parent a2a759f commit b1cf2ad
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 177 deletions.
5 changes: 4 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -159,17 +159,20 @@ VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0

GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
SIMILARITY_SEARCH_ENABLED=False # enable the similarity search for entities and chunks

KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top-k size of the knowledge graph search for chunks
KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text
COMMUNITY_SUMMARY_BATCH_SIZE=20 # the batch size for the parallel community summary process

KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE=20 # the batch size for embedding the text
### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data

Expand Down
5 changes: 3 additions & 2 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ async def embed(self, text: str) -> List[float]:
@abstractmethod
async def batch_embed(
self,
texts: List[str],
) -> List[List[float]]:
graphs_list: List[List],
batch_size: int = 1,
) -> List[List]:
"""Batch embed vectors from texts."""


Expand Down
77 changes: 58 additions & 19 deletions dbgpt/rag/transformer/graph_embedder.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,90 @@
"""GraphEmbedder class."""

import asyncio
import logging
from typing import List

from dbgpt.rag.transformer.text2vector import Text2Vector
from tenacity import retry, stop_after_attempt, wait_fixed

from dbgpt.core.interface.embeddings import Embeddings
from dbgpt.rag.transformer.base import EmbedderBase
from dbgpt.storage.graph_store.graph import Graph, GraphElemType

logger = logging.getLogger(__name__)


class GraphEmbedder(Text2Vector):
class GraphEmbedder(EmbedderBase):
"""GraphEmbedder class."""

def __init__(self):
"""Initialize the GraphEmbedder"""
def __init__(self, embedding_fn: Embeddings):
"""Initialize the GraphEmbedder."""
self.embedding_fn = embedding_fn
super().__init__()

async def embed(
self,
text: str,
) -> List[float]:
"""Embed"""
return await super()._embed(text)
"""Embed."""
return await self.embedding_fn.aembed_query(text)

async def batch_embed(
self,
graphs_list: List[List[Graph]],
batch_size: int = 1,
) -> List[List[Graph]]:
"""Embed graphs from graphs in batches"""

"""Embed graphs from graphs in batches."""
for graphs in graphs_list:
for graph in graphs:

texts = []
vectors = []

# Get the text from graph
for vertex in graph.vertices():
if vertex.get_prop("vertex_type") == GraphElemType.CHUNK.value:
text = vertex.get_prop("content")
vector = await self._embed(text)
vertex.set_prop("embedding", vector)
texts.append(vertex.get_prop("content"))
elif vertex.get_prop("vertex_type") == GraphElemType.ENTITY.value:
vector = await self._embed(vertex.vid)
vertex.set_prop("embedding", vector)
texts.append(vertex.vid)
else:
text = ""
texts.append(" ")

n_texts = len(texts)

# Batch embedding
for batch_idx in range(0, n_texts, batch_size):
start_idx = batch_idx
end_idx = min(start_idx + batch_size, n_texts)
batch_texts = texts[start_idx:end_idx]

# Create tasks
embedding_tasks = [(self._embed(text)) for text in batch_texts]

# Process embedding in parallel
batch_results = await asyncio.gather(
*(task for task in embedding_tasks), return_exceptions=True
)

# Place results in the correct positions
for idx, vector in enumerate(batch_results):
if isinstance(vector, Exception):
raise RuntimeError(f"Failed to embed text{idx}")
else:
vectors.append(vector)

# Push vectors back into Graph
for vertex, vector in zip(graph.vertices(), vectors):
vertex.set_prop("_embedding", vector)

return graphs_list


@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
async def _embed(self, text: str) -> List:
"""Inner embed."""
return await self.embedding_fn.aembed_query(text)

def truncate(self):
""""""
"""Do nothing by default."""

def drop(self):
""""""
"""Do nothing by default."""
35 changes: 10 additions & 25 deletions dbgpt/rag/transformer/text2vector.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,32 @@
"""Text2Vector class."""

import logging
from abc import ABC
from http import HTTPStatus
from typing import List

import dashscope

from dbgpt.core.interface.embeddings import Embeddings
from dbgpt.rag.transformer.base import EmbedderBase

logger = logging.getLogger(__name__)


class Text2Vector(EmbedderBase, ABC):
class Text2Vector(EmbedderBase):
"""Text2Vector class."""

def __init__(self):
"""Initialize the Embedder"""
def __init__(self, embedding_fn: Embeddings):
"""Initialize the Embedder."""
self.embedding_fn = embedding_fn
super().__init__()

async def embed(self, text: str) -> List[float]:
"""Embed vector from text."""
return await self._embed(text)
return await self.embedding_fn.aembed_query(text)

async def batch_embed(
self,
texts: List[str],
text_list: List[List],
batch_size: int = 1,
) -> List[List[float]]:
"""Batch embed vectors from texts."""
results = []
for text in texts:
vector = await self._embed(text)
results.extend(vector)
return results

async def _embed(self, text: str) -> List[float]:
"""Embed vector from text."""
resp = dashscope.TextEmbedding.call(
model = dashscope.TextEmbedding.Models.text_embedding_v3,
input = text,
dimension = 512)
embeddings = resp.output['embeddings']
embedding = embeddings[0]['embedding']
return list(embedding)
"""Embed texts from graphs in batches."""

def truncate(self):
"""Do nothing by default."""
Expand Down
30 changes: 30 additions & 0 deletions dbgpt/storage/graph_store/tugraph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ class TuGraphStoreConfig(GraphStoreConfig):
"/dbgpt-tugraph-plugins/tree/master/cpp"
),
)
similarity_search_enabled: bool = Field(
default=False,
description="Enable the similarity search",
)
similarity_search_topk: int = Field(
default=5,
description="Topk of knowledge graph extract",
)
similarity_search_score_threshold: float = Field(
default=0.3,
description="Recall score of knowledge graph extract",
)


class TuGraphStore(GraphStoreBase):
Expand All @@ -83,6 +95,11 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
)
self._similarity_search_enabled = (
os.environ["SIMILARITY_SEARCH_ENABLED"].lower() == "true"
if "SIMILARITY_SEARCH_ENABLED" in os.environ
else config.similarity_search_enabled
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
Expand All @@ -98,6 +115,19 @@ def __init__(self, config: TuGraphStoreConfig) -> None:
db_name=config.name,
)

self._similarity_search_topk = int(
os.getenv(
"KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE",
config.similarity_search_topk,
)
)
self._similarity_search_score_threshold = float(
os.getenv(
"KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_RECALL_SCORE",
config.similarity_search_score_threshold,
)
)

def get_config(self) -> TuGraphStoreConfig:
"""Get the TuGraph store config."""
return self._config
Expand Down
Loading

0 comments on commit b1cf2ad

Please sign in to comment.