feat: Enhance the triplets extraction in the knowledge graph by the batch size #2091

Merged on Nov 5, 2024 (20 commits)

Changes from 6 commits

Commits
89db275
feat: Improve triplet extraction batch size and handling
Appointat Oct 23, 2024
f8e3ed1
feat: Improve triplet extraction batch size and handling
Appointat Oct 23, 2024
a57029e
refactor: Add batch_extract method to ExtractorBase
Appointat Oct 28, 2024
3fc7640
refactor: Add batch_extract method to GraphExtractor
Appointat Oct 28, 2024
fee90cc
refactor: Add batch_extract method to LLMExtractor
Appointat Oct 28, 2024
ccd2cdf
refactor: Refactor CommunitySummaryKnowledgeGraph batch extraction me…
Appointat Oct 28, 2024
3f65e49
refactor: Update knowledge graph extraction batch size
Appointat Oct 29, 2024
a253542
refactor: Update knowledge graph extraction batch size
Appointat Oct 29, 2024
c565600
Refactor batch extraction methods in GraphExtractor and LLMExtractor
Appointat Oct 29, 2024
a4e602e
Refactor knowledge graph extraction batch size and method in Communit…
Appointat Oct 29, 2024
7d4d7f4
refactor: Refactor batch extraction methods in GraphExtractor and LLM…
Appointat Oct 29, 2024
5aaa393
feat: Refactor knowledge graph extraction batch size and method in Tu…
Appointat Oct 29, 2024
e8b82db
refactor: Update knowledge graph extraction batch size and method in …
Appointat Oct 29, 2024
0b87218
Refactor method signature in TuGraphStoreAdapter
Appointat Oct 29, 2024
e6f6d33
Refactor markdown format in community_summary.py
Appointat Oct 29, 2024
1ff3184
fix: Refactor graph store configuration and enable/disable graph search
Appointat Oct 30, 2024
a8f9321
chore: format the code
Appointat Oct 30, 2024
7e3c3c7
fix: Refactor TuGraphStoreAdapter to improve graph retrieval logic
Appointat Oct 30, 2024
0c263bf
fix
Appointat Oct 30, 2024
f0216d7
Refactor markdown format in community_summary.py
Appointat Oct 30, 2024
1 change: 1 addition & 0 deletions .env.template
@@ -167,6 +167,7 @@ TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
 DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks

 KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
+TRIPLET_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text

 ### Chroma vector db config
 #CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
14 changes: 13 additions & 1 deletion dbgpt/rag/transformer/base.py
@@ -1,7 +1,10 @@
 """Transformer base class."""

 import logging
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union
+
+from dbgpt.core import Chunk

 logger = logging.getLogger(__name__)
@@ -37,6 +40,15 @@ class ExtractorBase(TransformerBase, ABC):
     async def extract(self, text: str, limit: Optional[int] = None) -> List:
         """Extract results from text."""

+    @abstractmethod
+    async def batch_extract(
+        self,
+        texts: Union[List[str], List[Chunk]],
+        batch_size: int = 1,
+        limit: Optional[int] = None,
+    ) -> List:
+        """Batch extract results from texts."""


 class TranslatorBase(TransformerBase, ABC):
     """Translator base class."""
83 changes: 66 additions & 17 deletions dbgpt/rag/transformer/graph_extractor.py
@@ -1,8 +1,9 @@
 """GraphExtractor class."""

+import asyncio
 import logging
 import re
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple, Union

 from dbgpt.core import Chunk, LLMClient
 from dbgpt.rag.transformer.llm_extractor import LLMExtractor
@@ -21,37 +22,85 @@ def __init__(
         """Initialize the GraphExtractor."""
         super().__init__(llm_client, model_name, GRAPH_EXTRACT_PT_CN)
         self._chunk_history = chunk_history
+        self._chunk_context_map: Dict[str, str] = {}

         config = self._chunk_history.get_config()

         self._vector_space = config.name
         self._max_chunks_once_load = config.max_chunks_once_load
         self._max_threads = config.max_threads
         self._topk = config.topk
         self._score_threshold = config.score_threshold

-    async def extract(self, text: str, limit: Optional[int] = None) -> List:
-        """Load similar chunks."""
-        # load similar chunks
-        chunks = await self._chunk_history.asimilar_search_with_scores(
-            text, self._topk, self._score_threshold
-        )
-        history = [
-            f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
-        ]
-        context = "\n".join(history) if history else ""
-
-        try:
-            # extract with chunk history
-            return await super()._extract(text, context, limit)
+    async def aload_chunk_context(self, texts: List[str]) -> None:
+        """Load chunk context."""
+        for text in texts:
+            # Load similar chunks
+            chunks = await self._chunk_history.asimilar_search_with_scores(
+                text, self._topk, self._score_threshold
+            )
+            history = [
+                f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
+            ]

-        finally:
-            # save chunk to history
+            # Save chunk to history
             await self._chunk_history.aload_document_with_limit(
                 [Chunk(content=text, metadata={"relevant_cnt": len(history)})],
                 self._max_chunks_once_load,
                 self._max_threads,
             )

+            # Save chunk context to map
+            context = "\n".join(history) if history else ""
+            self._chunk_context_map[text] = context
+
+    async def extract(self, text: str, limit: Optional[int] = None) -> List:
+        """Load similar chunks.
+
+        Suggestion: to extract triplets in batches, call `batch_extract`.
+        """
+        if text not in self._chunk_context_map:
+            await self.aload_chunk_context([text])
+        context = self._chunk_context_map.get(text, "")
+
+        # Extract with chunk history
+        return await super()._extract(text, context, limit)
+
+    async def batch_extract(
+        self,
+        texts: Union[List[str], List[Chunk]],
+        batch_size: int = 1,
+        limit: Optional[int] = None,
+    ) -> List[Tuple[Chunk, List[Graph]]]:
+        """Extract graphs from chunks in batches."""
+        if isinstance(texts, list) and any(
+            not isinstance(chunk, Chunk) for chunk in texts
+        ):
+            raise ValueError("Chunks should be a list of Chunk objects, not strings.")
+        chunks: List[Chunk] = texts  # type: ignore[assignment]
+
+        # 1. Load chunk context
+        chunk_content_list = [chunk.content for chunk in chunks]
+        await self.aload_chunk_context(chunk_content_list)
+
+        chunk_graph_pairs: List[Tuple[Chunk, List[Graph]]] = []
+        total_batches = (len(chunks) + batch_size - 1) // batch_size
+
+        for batch_idx in range(total_batches):
+            start_idx = batch_idx * batch_size
+            end_idx = min((batch_idx + 1) * batch_size, len(chunks))
+            batch_chunks = chunks[start_idx:end_idx]
+
+            # 2. Process extraction in parallel
+            extraction_tasks = [
+                self.extract(chunk.content, limit) for chunk in batch_chunks
+            ]
+            batch_graphs = await asyncio.gather(*extraction_tasks)
+
+            # 3. Zip chunks with their corresponding graphs to keep the relationship
+            batch_graph_pairs = list(zip(batch_chunks, batch_graphs))
+            chunk_graph_pairs.extend(batch_graph_pairs)
+
+        return chunk_graph_pairs

     def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
         graph = MemoryGraph()
         edge_count = 0
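For orientation, a hedged usage sketch of the new batching path. `Chunk(content=...)`, `chunk.chunk_id`, and the pair-shaped return value come straight from this diff; the `extractor` itself (LLM client, model name, chunk-history store) is assumed to be configured elsewhere.

```python
import asyncio

from dbgpt.core import Chunk


async def extract_graphs(extractor, docs: list) -> None:
    # batch_extract rejects plain strings, so wrap raw text as Chunk objects.
    chunks = [Chunk(content=doc) for doc in docs]
    # At most `batch_size` extract() calls run concurrently within each batch.
    pairs = await extractor.batch_extract(chunks, batch_size=20)
    for chunk, graphs in pairs:
        print(chunk.chunk_id, len(graphs))


# asyncio.run(extract_graphs(graph_extractor, ["Alice is Bob's mother."]))
```

The ceil division `(len(chunks) + batch_size - 1) // batch_size` sizes the loop: for example, 45 chunks at `batch_size=20` yield three batches of 20, 20, and 5.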
20 changes: 18 additions & 2 deletions dbgpt/rag/transformer/llm_extractor.py
@@ -1,9 +1,10 @@
 """TripletExtractor class."""

 import logging
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union

-from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
+from dbgpt.core import Chunk, HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
 from dbgpt.rag.transformer.base import ExtractorBase

 logger = logging.getLogger(__name__)
@@ -22,6 +23,21 @@ async def extract(self, text: str, limit: Optional[int] = None) -> List:
         """Extract by LLM."""
         return await self._extract(text, None, limit)

+    async def batch_extract(
+        self,
+        texts: Union[List[str], List[Chunk]],
+        batch_size: int = 1,
+        limit: Optional[int] = None,
+    ) -> List:
+        """Batch extract by LLM."""
+        if isinstance(texts, list) and any(not isinstance(text, str) for text in texts):
+            raise ValueError("All elements must be strings")
+
+        results = []
+        for text in texts:
+            results.append(await self.extract(text, limit))
+        return results

     async def _extract(
         self, text: str, history: str = None, limit: Optional[int] = None
     ) -> List:
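Note that this default implementation runs sequentially and never reads `batch_size`; `GraphExtractor` overrides the method with true batching. A subclass that wants bounded concurrency without a full override could use a semaphore; the sketch below assumes only an async `extract(text, limit)` is available, and the helper name is illustrative.

```python
import asyncio
from typing import List, Optional


async def bounded_batch_extract(
    extractor, texts: List[str], batch_size: int = 1, limit: Optional[int] = None
) -> List:
    """Run extractor.extract() concurrently, at most batch_size calls in flight."""
    semaphore = asyncio.Semaphore(batch_size)

    async def one(text: str):
        async with semaphore:
            return await extractor.extract(text, limit)

    # gather preserves input order, so results line up with texts.
    return list(await asyncio.gather(*(one(t) for t in texts)))
```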
3 changes: 2 additions & 1 deletion dbgpt/rag/transformer/triplet_extractor.py
@@ -1,4 +1,5 @@
 """TripletExtractor class."""
+
 import logging
 import re
 from typing import Any, List, Optional, Tuple
@@ -12,7 +13,7 @@
     "Some text is provided below. Given the text, "
     "extract up to knowledge triplets as more as possible "
     "in the form of (subject, predicate, object).\n"
-    "Avoid stopwords.\n"
+    "Avoid stopwords. The subject, predicate, object can not be none.\n"
     "---------------------\n"
     "Example:\n"
     "Text: Alice is Bob's mother.\n"
23 changes: 15 additions & 8 deletions dbgpt/storage/knowledge_graph/community_summary.py
@@ -63,6 +63,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
         default=5,
         description="Top size of knowledge graph chunk search",
     )
+    triplet_extraction_batch_size: int = Field(
+        default=20,
+        description="Batch size of triplets extraction from the text",
+    )


 class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
@@ -96,6 +100,11 @@ def __init__(self, config: CommunitySummaryKnowledgeGraphConfig):
                 config.community_score_threshold,
             )
         )
+        self._triplet_extraction_batch_size = int(
+            os.getenv(
+                "TRIPLET_EXTRACTION_BATCH_SIZE", config.triplet_extraction_batch_size
+            )
+        )

         def extractor_configure(name: str, cfg: VectorStoreConfig):
             cfg.name = name
@@ -189,22 +198,20 @@ async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
             return

         document_graph_enabled = self._graph_store.get_config().document_graph_enabled
-        for chunk in chunks:
-            # TODO: Use asyncio to extract graph to accelerate the process
-            # (attention to the CAP of the graph db)
-
-            graphs: List[MemoryGraph] = await self._graph_extractor.extract(
-                chunk.content
-            )
+        chunk_graph_pairs = await self._graph_extractor.batch_extract(
+            chunks, batch_size=self._triplet_extraction_batch_size
+        )

+        for chunk, graphs in chunk_graph_pairs:
             for graph in graphs:
                 if document_graph_enabled:
-                    # append the chunk id to the edge
+                    # Append the chunk id to the edge
                     for edge in graph.edges():
                         edge.set_prop("_chunk_id", chunk.chunk_id)
                         graph.append_edge(edge=edge)

-                # upsert the graph
+                # Upsert the graph
                 self._graph_store_apdater.upsert_graph(graph)

                 # chunk -> include -> entity
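The batch size resolves like the other knowledge-graph settings above: the environment variable wins over the config default. A standalone sketch of that precedence (`DEFAULT_BATCH_SIZE` and `resolve_batch_size` are illustrative names mirroring the constructor logic):

```python
import os

DEFAULT_BATCH_SIZE = 20  # mirrors the triplet_extraction_batch_size Field default


def resolve_batch_size() -> int:
    # Same precedence as the constructor above: env var, then config default.
    return int(os.getenv("TRIPLET_EXTRACTION_BATCH_SIZE", DEFAULT_BATCH_SIZE))


assert resolve_batch_size() == 20
os.environ["TRIPLET_EXTRACTION_BATCH_SIZE"] = "8"
assert resolve_batch_size() == 8
```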
1 change: 1 addition & 0 deletions docs/docs/cookbook/rag/graph_rag_app_develop.md
@@ -116,6 +116,7 @@ GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
 TRIPLET_GRAPH_ENABLED=True # enable the graph search for the triplets
 DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
 KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the number of the searched triplets in a retrieval
+TRIPLET_EXTRACTION_BATCH_SIZE=20 # the batch size of triplet extraction from the text
 ```

