-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
363 additions
and
177 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,90 @@ | ||
"""GraphEmbedder class.""" | ||
|
||
import asyncio | ||
import logging | ||
from typing import List | ||
|
||
from dbgpt.rag.transformer.text2vector import Text2Vector | ||
from tenacity import retry, stop_after_attempt, wait_fixed | ||
|
||
from dbgpt.core.interface.embeddings import Embeddings | ||
from dbgpt.rag.transformer.base import EmbedderBase | ||
from dbgpt.storage.graph_store.graph import Graph, GraphElemType | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GraphEmbedder(Text2Vector): | ||
class GraphEmbedder(EmbedderBase): | ||
"""GraphEmbedder class.""" | ||
|
||
def __init__(self): | ||
"""Initialize the GraphEmbedder""" | ||
def __init__(self, embedding_fn: Embeddings): | ||
"""Initialize the GraphEmbedder.""" | ||
self.embedding_fn = embedding_fn | ||
super().__init__() | ||
|
||
async def embed( | ||
self, | ||
text: str, | ||
) -> List[float]: | ||
"""Embed""" | ||
return await super()._embed(text) | ||
"""Embed.""" | ||
return await self.embedding_fn.aembed_query(text) | ||
|
||
async def batch_embed( | ||
self, | ||
graphs_list: List[List[Graph]], | ||
batch_size: int = 1, | ||
) -> List[List[Graph]]: | ||
"""Embed graphs from graphs in batches""" | ||
|
||
"""Embed graphs from graphs in batches.""" | ||
for graphs in graphs_list: | ||
for graph in graphs: | ||
|
||
texts = [] | ||
vectors = [] | ||
|
||
# Get the text from graph | ||
for vertex in graph.vertices(): | ||
if vertex.get_prop("vertex_type") == GraphElemType.CHUNK.value: | ||
text = vertex.get_prop("content") | ||
vector = await self._embed(text) | ||
vertex.set_prop("embedding", vector) | ||
texts.append(vertex.get_prop("content")) | ||
elif vertex.get_prop("vertex_type") == GraphElemType.ENTITY.value: | ||
vector = await self._embed(vertex.vid) | ||
vertex.set_prop("embedding", vector) | ||
texts.append(vertex.vid) | ||
else: | ||
text = "" | ||
texts.append(" ") | ||
|
||
n_texts = len(texts) | ||
|
||
# Batch embedding | ||
for batch_idx in range(0, n_texts, batch_size): | ||
start_idx = batch_idx | ||
end_idx = min(start_idx + batch_size, n_texts) | ||
batch_texts = texts[start_idx:end_idx] | ||
|
||
# Create tasks | ||
embedding_tasks = [(self._embed(text)) for text in batch_texts] | ||
|
||
# Process embedding in parallel | ||
batch_results = await asyncio.gather( | ||
*(task for task in embedding_tasks), return_exceptions=True | ||
) | ||
|
||
# Place results in the correct positions | ||
for idx, vector in enumerate(batch_results): | ||
if isinstance(vector, Exception): | ||
raise RuntimeError(f"Failed to embed text{idx}") | ||
else: | ||
vectors.append(vector) | ||
|
||
# Push vectors back into Graph | ||
for vertex, vector in zip(graph.vertices(), vectors): | ||
vertex.set_prop("_embedding", vector) | ||
|
||
return graphs_list | ||
|
||
|
||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2)) | ||
async def _embed(self, text: str) -> List: | ||
"""Inner embed.""" | ||
return await self.embedding_fn.aembed_query(text) | ||
|
||
def truncate(self): | ||
"""""" | ||
"""Do nothing by default.""" | ||
|
||
def drop(self): | ||
"""""" | ||
"""Do nothing by default.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.