perf(rag): Support load large document (#1233)
fangyinc authored Mar 1, 2024
1 parent ed4df23 commit 505bc32
Showing 13 changed files with 231 additions and 37 deletions.
4 changes: 4 additions & 0 deletions .env.template
@@ -75,6 +75,10 @@ EMBEDDING_MODEL=text2vec
#EMBEDDING_MODEL=bge-large-zh
KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5
## Maximum number of chunks to load at once. If a single document is very large,
## you can raise this value for better performance.
## If you run out of memory when loading a large document, lower this value.
# KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=10
#KNOWLEDGE_CHUNK_OVERLAP=50
# Control whether to display the source document of knowledge on the front end.
KNOWLEDGE_CHAT_SHOW_RELATIONS=False
3 changes: 3 additions & 0 deletions dbgpt/_private/config.py
@@ -233,6 +233,9 @@ def __init__(self) -> None:
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
)
# default recall similarity score, between 0 and 1
self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
14 changes: 12 additions & 2 deletions dbgpt/app/knowledge/service.py
@@ -43,6 +43,7 @@
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace

knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
@@ -335,7 +336,11 @@ def _sync_knowledge_document(
)
from dbgpt.storage.vector_store.base import VectorStoreConfig

-        config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
+        config = VectorStoreConfig(
+            name=space_name,
+            embedding_fn=embedding_fn,
+            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
+        )
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
@@ -499,6 +504,7 @@ def get_document_chunks(self, request: ChunkQueryRequest):
res.page = request.page
return res

@trace("async_doc_embedding")
def async_doc_embedding(self, assembler, chunk_docs, doc):
"""async document embedding into vector db
Args:
@@ -511,7 +517,11 @@ def async_doc_embedding(self, assembler, chunk_docs, doc):
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
)
try:
-            vector_ids = assembler.persist()
+            with root_tracer.start_span(
+                "app.knowledge.assembler.persist",
+                metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
+            ):
+                vector_ids = assembler.persist()
doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success"
if vector_ids is not None:
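For context, a minimal sketch of the `@trace` decorator added above: it records the decorated call as a named span using the same tracer as the explicit `root_tracer.start_span` blocks in this commit. The function and span name here are hypothetical.

from dbgpt.util.tracer import trace


@trace("my_app.embed_document")  # free-form span name, hypothetical here
def embed_document(doc_name: str, chunk_count: int) -> int:
    # Stand-in for real embedding work; the call is timed and recorded
    # as a span when tracing is enabled.
    return chunk_count


embed_document("report.pdf", 42)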
14 changes: 7 additions & 7 deletions dbgpt/app/scene/chat_knowledge/v1/chat.py
@@ -11,18 +11,15 @@
)
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.scene import BaseChat, ChatScene
-from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import (
ChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
SystemPromptTemplate,
)
-from dbgpt.model import DefaultLLMClient
-from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.retriever.rewrite import QueryRewrite
-from dbgpt.util.tracer import trace
+from dbgpt.util.tracer import root_tracer, trace

CFG = Config()

@@ -226,6 +223,9 @@ def get_space_context(self, space_name):

async def execute_similar_search(self, query):
"""execute similarity search"""
-        return await self.embedding_retriever.aretrieve_with_scores(
-            query, self.recall_score
-        )
+        with root_tracer.start_span(
+            "execute_similar_search", metadata={"query": query}
+        ):
+            return await self.embedding_retriever.aretrieve_with_scores(
+                query, self.recall_score
+            )
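The same span pattern as a standalone sketch, assuming only that a tracer has been initialized: wrap the unit of work in a named span and attach free-form metadata.

from dbgpt.util.tracer import root_tracer

# Hypothetical span name and metadata; everything inside the block is
# timed and recorded under this span.
with root_tracer.start_span("my_app.similar_search", metadata={"query": "foo"}):
    result = sum(i * i for i in range(1000))  # stand-in for real work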
2 changes: 2 additions & 0 deletions dbgpt/configs/model_config.py
@@ -167,6 +167,8 @@ def get_device() -> str:
# https://huggingface.co/BAAI/bge-large-zh
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
"gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
6 changes: 4 additions & 2 deletions dbgpt/model/__init__.py
@@ -1,12 +1,14 @@
try:
-    from dbgpt.model.cluster.client import DefaultLLMClient
+    from dbgpt.model.cluster.client import DefaultLLMClient, RemoteLLMClient
except ImportError as exc:
    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
    DefaultLLMClient = None
+    RemoteLLMClient = None


_exports = []
if DefaultLLMClient:
    _exports.append("DefaultLLMClient")
+if RemoteLLMClient:
+    _exports.append("RemoteLLMClient")

__ALL__ = _exports
57 changes: 57 additions & 0 deletions dbgpt/model/cluster/client.py
@@ -104,3 +104,60 @@ async def models(self) -> List[ModelMetadata]:

async def count_token(self, model: str, prompt: str) -> int:
return await self.worker_manager.count_token({"model": model, "prompt": prompt})


@register_resource(
label="Remote LLM Client",
name="remote_llm_client",
category=ResourceCategory.LLM_CLIENT,
description="Remote LLM client(Connect to the remote DB-GPT model serving)",
parameters=[
Parameter.build_from(
"Controller Address",
name="controller_address",
type=str,
optional=True,
default="http://127.0.0.1:8000",
description="Model controller address",
),
Parameter.build_from(
"Auto Convert Message",
name="auto_convert_message",
type=bool,
optional=True,
default=False,
description="Whether to auto convert the messages that are not supported "
"by the LLM to a compatible format",
),
],
)
class RemoteLLMClient(DefaultLLMClient):
"""Remote LLM client implementation.
Connect to the remote worker manager and send the request to the remote worker manager.
Args:
controller_address (str): model controller address
auto_convert_message (bool, optional): auto convert the message to
ModelRequest. Defaults to False.
If you start DB-GPT model cluster, the controller address is the address of the
Model Controller(`dbgpt start controller`, the default port of model controller
is 8000).
Otherwise, if you already have a running DB-GPT server(start it by
`dbgpt start webserver --port ${remote_port}`), you can use the address of the
`http://${remote_ip}:${remote_port}`.
"""

def __init__(
self,
controller_address: str = "http://127.0.0.1:8000",
auto_convert_message: bool = False,
):
"""Initialize the RemoteLLMClient."""
from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager

model_registry_client = ModelRegistryClient(controller_address)
worker_manager = RemoteWorkerManager(model_registry_client)
super().__init__(worker_manager, auto_convert_message)
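A hedged usage sketch for the new client, assuming a model controller is running on the default port and serves at least one model. `models()` and `count_token()` are the inherited methods shown in the context above; the `model` field on `ModelMetadata` is an assumption.

import asyncio

from dbgpt.model import RemoteLLMClient


async def main():
    client = RemoteLLMClient(controller_address="http://127.0.0.1:8000")
    metadata = await client.models()  # models served by the cluster
    model = metadata[0].model         # assumed ModelMetadata field
    tokens = await client.count_token(model, "Hello, DB-GPT!")
    print(f"{model}: {tokens} tokens")


asyncio.run(main())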
55 changes: 39 additions & 16 deletions dbgpt/rag/retriever/embedding.py
@@ -7,6 +7,7 @@
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.util.tracer import root_tracer


class EmbeddingRetriever(BaseRetriever):
@@ -129,23 +130,45 @@ async def _aretrieve_with_score(
"""
queries = [query]
if self._query_rewrite:
candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
with root_tracer.start_span(
"EmbeddingRetriever.query_rewrite.similarity_search",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_tasks = [self._similarity_search(query) for query in queries]
chunks = await self._run_async_tasks(candidates_tasks)
context = "\n".join([chunk.content for chunk in chunks])
with root_tracer.start_span(
"EmbeddingRetriever.query_rewrite.rewrite",
metadata={"query": query, "context": context, "nums": 1},
):
new_queries = await self._query_rewrite.rewrite(
origin_query=query, context=context, nums=1
)
queries.extend(new_queries)

with root_tracer.start_span(
"EmbeddingRetriever.similarity_search_with_score",
metadata={"query": query, "score_threshold": score_threshold},
):
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
for query in queries
]
candidates_with_score = await run_async_tasks(
tasks=candidates_with_score, concurrency_limit=1
)
queries.extend(new_queries)
candidates_with_score = [
self._similarity_search_with_score(query, score_threshold)
for query in queries
]
candidates_with_score = await run_async_tasks(
tasks=candidates_with_score, concurrency_limit=1
)
candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)
candidates_with_score = self._rerank.rank(candidates_with_score)
return candidates_with_score
candidates_with_score = reduce(lambda x, y: x + y, candidates_with_score)

with root_tracer.start_span(
"EmbeddingRetriever.rerank",
metadata={
"query": query,
"score_threshold": score_threshold,
"rerank_cls": self._rerank.__class__.__name__,
},
):
candidates_with_score = self._rerank.rank(candidates_with_score)
return candidates_with_score

async def _similarity_search(self, query) -> List[Chunk]:
"""Similar search."""
11 changes: 7 additions & 4 deletions dbgpt/rag/text_splitter/text_splitter.py
@@ -1,6 +1,5 @@
import copy
import logging
-import re
from abc import ABC, abstractmethod
from typing import (
Any,
@@ -66,10 +65,14 @@ def create_documents(
chunks.append(new_doc)
return chunks

-    def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
+    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
        """Split documents."""
-        texts = [doc.content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
+        texts = []
+        metadatas = []
+        for doc in documents:
+            # An Iterable may support only a single pass, so collect content
+            # and metadata in one iteration
+            texts.append(doc.content)
+            metadatas.append(doc.metadata)
return self.create_documents(texts, metadatas, **kwargs)

def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:
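The point of loosening the signature from `List` to `Iterable`: a lazy generator can now be passed and is consumed in a single pass. A sketch with a toy splitter, assuming `TextSplitter` subclasses only need to implement `split_text`, its constructor has usable defaults, and `Document` takes `content` and `metadata` keyword arguments.

from typing import List

from dbgpt.rag.chunk import Document
from dbgpt.rag.text_splitter.text_splitter import TextSplitter


class WholeTextSplitter(TextSplitter):
    """Toy splitter: each document becomes a single chunk."""

    def split_text(self, text: str, **kwargs) -> List[str]:
        return [text]


# A generator supports only one pass; split_documents now collects
# content and metadata in that single pass.
docs = (Document(content=f"text {i}", metadata={"i": i}) for i in range(3))
chunks = WholeTextSplitter().split_documents(docs)
print(len(chunks))  # expected: 3, one chunk per document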
20 changes: 17 additions & 3 deletions dbgpt/serve/rag/assembler/base.py
@@ -6,6 +6,7 @@
from dbgpt.rag.extractor.base import Extractor
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.util.tracer import root_tracer, trace


class BaseAssembler(ABC):
@@ -30,12 +31,25 @@ def __init__(
knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
)
self._chunks = None
-        self.load_knowledge(self._knowledge)
+        metadata = {
+            "knowledge_cls": self._knowledge.__class__.__name__
+            if self._knowledge
+            else None,
+            "knowledge_type": self._knowledge.type().value if self._knowledge else None,
+            "path": self._knowledge._path
+            if self._knowledge and hasattr(self._knowledge, "_path")
+            else None,
+            "chunk_parameters": self._chunk_parameters.dict(),
+        }
+        with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
+            self.load_knowledge(self._knowledge)

def load_knowledge(self, knowledge) -> None:
"""Load knowledge Pipeline."""
-        documents = knowledge.load()
-        self._chunks = self._chunk_manager.split(documents)
+        with root_tracer.start_span("BaseAssembler.knowledge.load"):
+            documents = knowledge.load()
+        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
+            self._chunks = self._chunk_manager.split(documents)

@abstractmethod
def as_retriever(self, **kwargs: Any) -> BaseRetriever:
37 changes: 37 additions & 0 deletions dbgpt/storage/vector_store/base.py
@@ -1,11 +1,15 @@
import logging
import math
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional

from pydantic import BaseModel, Field

from dbgpt.rag.chunk import Chunk

logger = logging.getLogger(__name__)


class VectorStoreConfig(BaseModel):
"""Vector store config."""
@@ -26,6 +30,12 @@ class VectorStoreConfig(BaseModel):
default=None,
description="The embedding function of vector store, if not set, will use the default embedding function.",
)
max_chunks_once_load: int = Field(
default=10,
description="The max number of chunks to load at once. If your document is "
"large, you can set this value to a larger number to speed up the loading "
"process. Default is 10.",
)


class VectorStoreBase(ABC):
@@ -41,6 +51,33 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
"""
pass

def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10
) -> List[str]:
"""load document in vector database with limit.
Args:
chunks: document chunks.
max_chunks_once_load: Max number of chunks to load at once.
Return:
"""
        # Group the chunks into batches of at most max_chunks_once_load chunks
chunk_groups = [
chunks[i : i + max_chunks_once_load]
for i in range(0, len(chunks), max_chunks_once_load)
]
logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
ids = []
loaded_cnt = 0
start_time = time.time()
for chunk_group in chunk_groups:
ids.extend(self.load_document(chunk_group))
loaded_cnt += len(chunk_group)
logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
logger.info(
f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
)
return ids

@abstractmethod
def similar_search(self, text, topk) -> List[Chunk]:
"""similar search in vector database.
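The heart of the change is the grouping step inside `load_document_with_limit`. A standalone sketch of just that logic, with integers standing in for `Chunk` objects:

chunks = list(range(25))  # stand-ins for 25 Chunk objects
max_chunks_once_load = 10

# Same slicing as load_document_with_limit: groups of at most
# max_chunks_once_load chunks, one load_document call per group
chunk_groups = [
    chunks[i : i + max_chunks_once_load]
    for i in range(0, len(chunks), max_chunks_once_load)
]
print([len(g) for g in chunk_groups])  # -> [10, 10, 5]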
4 changes: 3 additions & 1 deletion dbgpt/storage/vector_store/connector.py
@@ -64,7 +64,9 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
- chunks: document chunks.
Return chunk ids.
"""
-        return self.client.load_document(chunks)
+        return self.client.load_document_with_limit(
+            chunks, self._vector_store_config.max_chunks_once_load
+        )

def similar_search(self, doc: str, topk: int) -> List[Chunk]:
"""similar search in vector database.
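Putting it together, a hedged end-to-end sketch: set the batch size once on `VectorStoreConfig`, and every load through the connector is batched. The store type, space name, and omitted embedding function are assumptions; match them to your deployment.

from dbgpt.rag.chunk import Chunk
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

config = VectorStoreConfig(
    name="my_space",          # hypothetical space name
    max_chunks_once_load=20,  # larger batches: faster load, more memory
)
connector = VectorStoreConnector(
    vector_store_type="Chroma",  # assumed store type; match your setup
    vector_store_config=config,
)
ids = connector.load_document([Chunk(content="hello world")])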