From 505bc32775032425b971c2d391e7673970b76ffd Mon Sep 17 00:00:00 2001
From: Fangyin Cheng
Date: Fri, 1 Mar 2024 22:33:52 +0800
Subject: [PATCH] perf(rag): Support load large document (#1233)

---
 .env.template                             |  4 ++
 dbgpt/_private/config.py                  |  3 ++
 dbgpt/app/knowledge/service.py            | 14 +++++-
 dbgpt/app/scene/chat_knowledge/v1/chat.py | 14 +++---
 dbgpt/configs/model_config.py             |  2 +
 dbgpt/model/__init__.py                   |  6 ++-
 dbgpt/model/cluster/client.py             | 57 +++++++++++++++++++++++
 dbgpt/rag/retriever/embedding.py          | 55 +++++++++++++++-------
 dbgpt/rag/text_splitter/text_splitter.py  | 11 +++--
 dbgpt/serve/rag/assembler/base.py         | 20 ++++++--
 dbgpt/storage/vector_store/base.py        | 37 +++++++++++++++
 dbgpt/storage/vector_store/connector.py   |  4 +-
 dbgpt/util/tracer/base.py                 | 41 +++++++++++++++-
 13 files changed, 231 insertions(+), 37 deletions(-)

diff --git a/.env.template b/.env.template
index 046c4be56..ba5a4bb51 100644
--- a/.env.template
+++ b/.env.template
@@ -75,6 +75,10 @@ EMBEDDING_MODEL=text2vec
 #EMBEDDING_MODEL=bge-large-zh
 KNOWLEDGE_CHUNK_SIZE=500
 KNOWLEDGE_SEARCH_TOP_SIZE=5
+## Maximum number of chunks to load at once. If a single document is very large,
+## you can raise this value for better performance.
+## If you run out of memory while loading a large document, lower this value.
+# KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD=10
 #KNOWLEDGE_CHUNK_OVERLAP=50
 # Control whether to display the source document of knowledge on the front end.
 KNOWLEDGE_CHAT_SHOW_RELATIONS=False
diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index f1b418229..f39b1921e 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -233,6 +233,9 @@ def __init__(self) -> None:
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
         self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
+        self.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD = int(
+            os.getenv("KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD", 10)
+        )
         # default recall similarity score, between 0 and 1
         self.KNOWLEDGE_SEARCH_RECALL_SCORE = float(
             os.getenv("KNOWLEDGE_SEARCH_RECALL_SCORE", 0.3)
diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py
index a10869f66..8e34e47cc 100644
--- a/dbgpt/app/knowledge/service.py
+++ b/dbgpt/app/knowledge/service.py
@@ -43,6 +43,7 @@
 from dbgpt.storage.vector_store.base import VectorStoreConfig
 from dbgpt.storage.vector_store.connector import VectorStoreConnector
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
+from dbgpt.util.tracer import root_tracer, trace
 
 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()
@@ -335,7 +336,11 @@ def _sync_knowledge_document(
         )
         from dbgpt.storage.vector_store.base import VectorStoreConfig
-        config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
+        config = VectorStoreConfig(
+            name=space_name,
+            embedding_fn=embedding_fn,
+            max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
+        )
         vector_store_connector = VectorStoreConnector(
             vector_store_type=CFG.VECTOR_STORE_TYPE,
             vector_store_config=config,
         )
@@ -499,6 +504,7 @@ def get_document_chunks(self, request: ChunkQueryRequest):
         res.page = request.page
         return res
 
+    @trace("async_doc_embedding")
     def async_doc_embedding(self, assembler, chunk_docs, doc):
         """async document embedding into vector db
         Args:
@@ -511,7 +517,11 @@ def async_doc_embedding(self, assembler, chunk_docs, doc):
             f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}, begin embedding to vector store-{CFG.VECTOR_STORE_TYPE}"
         )
         try:
-            vector_ids = assembler.persist()
+            with root_tracer.start_span(
+                "app.knowledge.assembler.persist",
+                metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)},
+            ):
+                vector_ids = assembler.persist()
             doc.status = SyncStatus.FINISHED.name
             doc.result = "document embedding success"
             if vector_ids is not None:
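The two tracer entry points used above are the `@trace` decorator, which wraps a whole
function call in a span, and `root_tracer.start_span`, which opens an explicit nested
span. A minimal usage sketch, assuming a configured tracer; the span names and function
are illustrative, not part of this patch:

    from dbgpt.util.tracer import root_tracer, trace

    @trace("knowledge.sync_document")  # spans the entire function call
    def sync_document(doc_name: str, chunks: list):
        # Nested span with searchable metadata. Metadata must be JSON-serializable;
        # anything else is dropped by _clean_for_json (see the tracer change at the
        # end of this patch).
        with root_tracer.start_span(
            "knowledge.sync_document.persist",
            metadata={"doc": doc_name, "chunks": len(chunks)},
        ):
            pass  # do the actual persisting here
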
diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py
index 6f60e46e0..80bad1429 100644
--- a/dbgpt/app/scene/chat_knowledge/v1/chat.py
+++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py
@@ -11,7 +11,6 @@
 )
 from dbgpt.app.knowledge.service import KnowledgeService
 from dbgpt.app.scene import BaseChat, ChatScene
-from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.core import (
     ChatPromptTemplate,
@@ -19,10 +18,8 @@
     MessagesPlaceholder,
     SystemPromptTemplate,
 )
-from dbgpt.model import DefaultLLMClient
-from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.retriever.rewrite import QueryRewrite
-from dbgpt.util.tracer import trace
+from dbgpt.util.tracer import root_tracer, trace
 
 CFG = Config()
 
@@ -226,6 +223,9 @@ def get_space_context(self, space_name):
 
     async def execute_similar_search(self, query):
         """execute similarity search"""
-        return await self.embedding_retriever.aretrieve_with_scores(
-            query, self.recall_score
-        )
+        with root_tracer.start_span(
+            "execute_similar_search", metadata={"query": query}
+        ):
+            return await self.embedding_retriever.aretrieve_with_scores(
+                query, self.recall_score
+            )
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index c6d475cb8..c4d70ec4c 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -167,6 +167,8 @@ def get_device() -> str:
     # https://huggingface.co/BAAI/bge-large-zh
     "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
     "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
+    "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
+    "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "proxy_openai": "proxy_openai",
     "proxy_azure": "proxy_azure",
diff --git a/dbgpt/model/__init__.py b/dbgpt/model/__init__.py
index 0e66b2ff1..13bf3f0e3 100644
--- a/dbgpt/model/__init__.py
+++ b/dbgpt/model/__init__.py
@@ -1,12 +1,14 @@
 try:
-    from dbgpt.model.cluster.client import DefaultLLMClient
+    from dbgpt.model.cluster.client import DefaultLLMClient, RemoteLLMClient
 except ImportError as exc:
-    # logging.warning("Can't import dbgpt.model.DefaultLLMClient")
     DefaultLLMClient = None
+    RemoteLLMClient = None
 
 _exports = []
 if DefaultLLMClient:
     _exports.append("DefaultLLMClient")
+if RemoteLLMClient:
+    _exports.append("RemoteLLMClient")
 
 __ALL__ = _exports
diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py
index 1a7d99aa3..4e1ed2e4e 100644
--- a/dbgpt/model/cluster/client.py
+++ b/dbgpt/model/cluster/client.py
@@ -104,3 +104,60 @@ async def models(self) -> List[ModelMetadata]:
 
     async def count_token(self, model: str, prompt: str) -> int:
         return await self.worker_manager.count_token({"model": model, "prompt": prompt})
+
+
+@register_resource(
+    label="Remote LLM Client",
+    name="remote_llm_client",
+    category=ResourceCategory.LLM_CLIENT,
+    description="Remote LLM client (connects to a remote DB-GPT model serving "
+    "deployment)",
+    parameters=[
+        Parameter.build_from(
+            "Controller Address",
+            name="controller_address",
+            type=str,
+            optional=True,
+            default="http://127.0.0.1:8000",
+            description="Model controller address",
+        ),
+        Parameter.build_from(
+            "Auto Convert Message",
+            name="auto_convert_message",
+            type=bool,
+            optional=True,
+            default=False,
+            description="Whether to automatically convert messages that are not "
+            "supported by the LLM to a compatible format",
+        ),
+    ],
+)
+class RemoteLLMClient(DefaultLLMClient):
+    """Remote LLM client implementation.
+
+    Connects to a remote worker manager and forwards requests to it.
+
+    Args:
+        controller_address (str): Model controller address.
+        auto_convert_message (bool, optional): Whether to automatically convert
+            unsupported messages to a compatible format. Defaults to False.
+
+    If you start a DB-GPT model cluster, the controller address is the address of
+    the Model Controller (`dbgpt start controller`; the default port of the model
+    controller is 8000).
+    Otherwise, if you already have a running DB-GPT webserver (started with
+    `dbgpt start webserver --port ${remote_port}`), you can use
+    `http://${remote_ip}:${remote_port}` as the address.
+
+    """
+
+    def __init__(
+        self,
+        controller_address: str = "http://127.0.0.1:8000",
+        auto_convert_message: bool = False,
+    ):
+        """Initialize the RemoteLLMClient."""
+        from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager
+
+        model_registry_client = ModelRegistryClient(controller_address)
+        worker_manager = RemoteWorkerManager(model_registry_client)
+        super().__init__(worker_manager, auto_convert_message)
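A minimal usage sketch for the new RemoteLLMClient, assuming a Model Controller is
already running on the default port. The model name is illustrative; `models()` and
`count_token()` are the DefaultLLMClient methods shown above:

    import asyncio

    from dbgpt.model import RemoteLLMClient

    async def main():
        client = RemoteLLMClient(controller_address="http://127.0.0.1:8000")
        # List the models registered in the cluster.
        for metadata in await client.models():
            print(metadata.model)
        # Count tokens for a prompt with a deployed model (illustrative name).
        print(await client.count_token("vicuna-13b-v1.5", "Hello, DB-GPT"))

    asyncio.run(main())
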
Address", + name="controller_address", + type=str, + optional=True, + default="http://127.0.0.1:8000", + description="Model controller address", + ), + Parameter.build_from( + "Auto Convert Message", + name="auto_convert_message", + type=bool, + optional=True, + default=False, + description="Whether to auto convert the messages that are not supported " + "by the LLM to a compatible format", + ), + ], +) +class RemoteLLMClient(DefaultLLMClient): + """Remote LLM client implementation. + + Connect to the remote worker manager and send the request to the remote worker manager. + + Args: + controller_address (str): model controller address + auto_convert_message (bool, optional): auto convert the message to + ModelRequest. Defaults to False. + + If you start DB-GPT model cluster, the controller address is the address of the + Model Controller(`dbgpt start controller`, the default port of model controller + is 8000). + Otherwise, if you already have a running DB-GPT server(start it by + `dbgpt start webserver --port ${remote_port}`), you can use the address of the + `http://${remote_ip}:${remote_port}`. + + """ + + def __init__( + self, + controller_address: str = "http://127.0.0.1:8000", + auto_convert_message: bool = False, + ): + """Initialize the RemoteLLMClient.""" + from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager + + model_registry_client = ModelRegistryClient(controller_address) + worker_manager = RemoteWorkerManager(model_registry_client) + super().__init__(worker_manager, auto_convert_message) diff --git a/dbgpt/rag/retriever/embedding.py b/dbgpt/rag/retriever/embedding.py index a9e24065f..fd4c687b4 100644 --- a/dbgpt/rag/retriever/embedding.py +++ b/dbgpt/rag/retriever/embedding.py @@ -7,6 +7,7 @@ from dbgpt.rag.retriever.rewrite import QueryRewrite from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.util.chat_util import run_async_tasks +from dbgpt.util.tracer import root_tracer class EmbeddingRetriever(BaseRetriever): @@ -129,23 +130,45 @@ async def _aretrieve_with_score( """ queries = [query] if self._query_rewrite: - candidates_tasks = [self._similarity_search(query) for query in queries] - chunks = await self._run_async_tasks(candidates_tasks) - context = "\n".join([chunk.content for chunk in chunks]) - new_queries = await self._query_rewrite.rewrite( - origin_query=query, context=context, nums=1 + with root_tracer.start_span( + "EmbeddingRetriever.query_rewrite.similarity_search", + metadata={"query": query, "score_threshold": score_threshold}, + ): + candidates_tasks = [self._similarity_search(query) for query in queries] + chunks = await self._run_async_tasks(candidates_tasks) + context = "\n".join([chunk.content for chunk in chunks]) + with root_tracer.start_span( + "EmbeddingRetriever.query_rewrite.rewrite", + metadata={"query": query, "context": context, "nums": 1}, + ): + new_queries = await self._query_rewrite.rewrite( + origin_query=query, context=context, nums=1 + ) + queries.extend(new_queries) + + with root_tracer.start_span( + "EmbeddingRetriever.similarity_search_with_score", + metadata={"query": query, "score_threshold": score_threshold}, + ): + candidates_with_score = [ + self._similarity_search_with_score(query, score_threshold) + for query in queries + ] + candidates_with_score = await run_async_tasks( + tasks=candidates_with_score, concurrency_limit=1 ) - queries.extend(new_queries) - candidates_with_score = [ - self._similarity_search_with_score(query, score_threshold) - for query in queries - ] - 
diff --git a/dbgpt/rag/text_splitter/text_splitter.py b/dbgpt/rag/text_splitter/text_splitter.py
index 21b8cae02..75d74fb41 100644
--- a/dbgpt/rag/text_splitter/text_splitter.py
+++ b/dbgpt/rag/text_splitter/text_splitter.py
@@ -1,6 +1,5 @@
 import copy
 import logging
-import re
 from abc import ABC, abstractmethod
 from typing import (
     Any,
@@ -66,10 +65,14 @@ def create_documents(
             chunks.append(new_doc)
         return chunks
 
-    def split_documents(self, documents: List[Document], **kwargs) -> List[Chunk]:
+    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
         """Split documents."""
-        texts = [doc.content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
+        texts = []
+        metadatas = []
+        for doc in documents:
+            # An Iterable may support only a single pass, so collect both fields in one loop
+            texts.append(doc.content)
+            metadatas.append(doc.metadata)
         return self.create_documents(texts, metadatas, **kwargs)
 
     def _join_docs(self, docs: List[str], separator: str, **kwargs) -> Optional[str]:
diff --git a/dbgpt/serve/rag/assembler/base.py b/dbgpt/serve/rag/assembler/base.py
index 1f9a1de4b..31aa3ce25 100644
--- a/dbgpt/serve/rag/assembler/base.py
+++ b/dbgpt/serve/rag/assembler/base.py
@@ -6,6 +6,7 @@
 from dbgpt.rag.extractor.base import Extractor
 from dbgpt.rag.knowledge.base import Knowledge
 from dbgpt.rag.retriever.base import BaseRetriever
+from dbgpt.util.tracer import root_tracer, trace
 
 
 class BaseAssembler(ABC):
@@ -30,12 +31,25 @@ def __init__(
             knowledge=self._knowledge, chunk_parameter=self._chunk_parameters
         )
         self._chunks = None
-        self.load_knowledge(self._knowledge)
+        metadata = {
+            "knowledge_cls": self._knowledge.__class__.__name__
+            if self._knowledge
+            else None,
+            "knowledge_type": self._knowledge.type().value if self._knowledge else None,
+            "path": self._knowledge._path
+            if self._knowledge and hasattr(self._knowledge, "_path")
+            else None,
+            "chunk_parameters": self._chunk_parameters.dict(),
+        }
+        with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
+            self.load_knowledge(self._knowledge)
 
     def load_knowledge(self, knowledge) -> None:
         """Load knowledge Pipeline."""
-        documents = knowledge.load()
-        self._chunks = self._chunk_manager.split(documents)
+        with root_tracer.start_span("BaseAssembler.knowledge.load"):
+            documents = knowledge.load()
+        with root_tracer.start_span("BaseAssembler.chunk_manager.split"):
+            self._chunks = self._chunk_manager.split(documents)
 
     @abstractmethod
     def as_retriever(self, **kwargs: Any) -> BaseRetriever:
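Because `split_documents` now accepts any `Iterable`, callers can stream documents from
a generator instead of materializing them all up front, which matters for the large
documents this patch targets. A sketch under the assumption that `Document` lives in
`dbgpt.rag.chunk` and that `CharacterTextSplitter` accepts these keyword arguments; the
file list is illustrative:

    from dbgpt.rag.chunk import Document
    from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter

    def iter_documents(paths):
        # Lazily yield one Document per file; nothing is held in memory up front.
        for path in paths:
            with open(path, encoding="utf-8") as f:
                yield Document(content=f.read(), metadata={"source": path})

    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_documents(iter_documents(["a.md", "b.md"]))
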
diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py
index 2d2918f71..a97fdd601 100644
--- a/dbgpt/storage/vector_store/base.py
+++ b/dbgpt/storage/vector_store/base.py
@@ -1,4 +1,6 @@
+import logging
 import math
+import time
 from abc import ABC, abstractmethod
 from typing import Any, Callable, List, Optional
 
@@ -6,6 +8,8 @@
 
 from dbgpt.rag.chunk import Chunk
 
+logger = logging.getLogger(__name__)
+
 
 class VectorStoreConfig(BaseModel):
     """Vector store config."""
@@ -26,6 +30,12 @@ class VectorStoreConfig(BaseModel):
         default=None,
         description="The embedding function of vector store, if not set, will use the default embedding function.",
     )
+    max_chunks_once_load: int = Field(
+        default=10,
+        description="The max number of chunks to load at once. If your document is "
+        "large, you can set this value to a larger number to speed up the loading "
+        "process. Default is 10.",
+    )
 
 
 class VectorStoreBase(ABC):
@@ -41,6 +51,33 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
         """
         pass
 
+    def load_document_with_limit(
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+    ) -> List[str]:
+        """Load documents into the vector database in batches.
+        Args:
+            chunks: document chunks.
+            max_chunks_once_load: Max number of chunks to load at once.
+        Return: chunk ids.
+        """
+        # Group the chunks into batches of size max_chunks_once_load
+        chunk_groups = [
+            chunks[i : i + max_chunks_once_load]
+            for i in range(0, len(chunks), max_chunks_once_load)
+        ]
+        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        ids = []
+        loaded_cnt = 0
+        start_time = time.time()
+        for chunk_group in chunk_groups:
+            ids.extend(self.load_document(chunk_group))
+            loaded_cnt += len(chunk_group)
+            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        logger.info(
+            f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
+        )
+        return ids
+
     @abstractmethod
     def similar_search(self, text, topk) -> List[Chunk]:
         """similar search in vector database.
diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py
index 50e7148a0..ddcfc2f06 100644
--- a/dbgpt/storage/vector_store/connector.py
+++ b/dbgpt/storage/vector_store/connector.py
@@ -64,7 +64,9 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
         - chunks: document chunks.
         Return chunk ids.
         """
-        return self.client.load_document(chunks)
+        return self.client.load_document_with_limit(
+            chunks, self._vector_store_config.max_chunks_once_load
+        )
 
     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
         """similar search in vector database.
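With this change, every `load_document` call on the connector routes through the
batched `load_document_with_limit`. A configuration sketch, assuming a Chroma-backed
store and an embedding function built elsewhere; names and values are illustrative:

    from dbgpt.storage.vector_store.base import VectorStoreConfig
    from dbgpt.storage.vector_store.connector import VectorStoreConnector

    config = VectorStoreConfig(
        name="my_space",
        embedding_fn=embedding_fn,  # built elsewhere
        max_chunks_once_load=20,    # larger batches: faster load, more memory
    )
    connector = VectorStoreConnector(
        vector_store_type="Chroma",
        vector_store_config=config,
    )
    ids = connector.load_document(chunks)  # loads in groups of 20 under the hood
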
diff --git a/dbgpt/util/tracer/base.py b/dbgpt/util/tracer/base.py
index 77f8049f1..95d5ac09c 100644
--- a/dbgpt/util/tracer/base.py
+++ b/dbgpt/util/tracer/base.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
+import json
 import uuid
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
 
@@ -95,7 +96,7 @@ def to_dict(self) -> Dict:
             "end_time": None
             if not self.end_time
             else self.end_time.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
-            "metadata": self.metadata,
+            "metadata": _clean_for_json(self.metadata),
         }
 
 
@@ -187,3 +188,39 @@ def _new_uuid(self) -> str:
 @dataclass
 class TracerContext:
     span_id: Optional[str] = None
+
+
+def _clean_for_json(data: Any = None):
+    if data is None:
+        return None
+    if isinstance(data, dict):
+        cleaned_dict = {}
+        for key, value in data.items():
+            # Try to clean the sub-items
+            cleaned_value = _clean_for_json(value)
+            if cleaned_value is not None:
+                # Only add to the cleaned dict if it's not None
+                try:
+                    json.dumps({key: cleaned_value})
+                    cleaned_dict[key] = cleaned_value
+                except TypeError:
+                    # Skip this key-value pair if it can't be serialized
+                    pass
+        return cleaned_dict
+    elif isinstance(data, list):
+        cleaned_list = []
+        for item in data:
+            cleaned_item = _clean_for_json(item)
+            if cleaned_item is not None:
+                try:
+                    json.dumps(cleaned_item)
+                    cleaned_list.append(cleaned_item)
+                except TypeError:
+                    pass
+        return cleaned_list
+    else:
+        try:
+            json.dumps(data)
+            return data
+        except TypeError:
+            return None
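A quick illustration of what the cleaner does to span metadata before it is serialized;
the values are illustrative and `_clean_for_json` is the helper defined above:

    metadata = {
        "query": "hello",
        "callback": lambda x: x,              # not JSON-serializable, dropped
        "nested": {"n": 1, "obj": object()},  # "obj" dropped, "n" kept
    }
    print(_clean_for_json(metadata))
    # {'query': 'hello', 'nested': {'n': 1}}

Note the signature fix above: the patch originally annotated the argument as
`Optional[str, Any]`, which is not a valid typing form (`Optional` takes a single
parameter and would raise a TypeError at import time), and its `if not data` guard
silently discarded falsy metadata values such as 0 and False; `Any` with an explicit
`is None` check preserves them.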