From 3bd625e8de78f28fa12c2dab6e555302302f0a4e Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Fri, 29 Dec 2023 10:08:20 +0800
Subject: [PATCH] feat: ChatKnowledge add prompt token count

---
 dbgpt/app/knowledge/service.py            | 24 +++++++++++++++++++-----
 dbgpt/app/scene/chat_knowledge/v1/chat.py | 30 ++++++++++++++++++++++------
 dbgpt/util/prompt_util.py                 |  5 +++++
 3 files changed, 46 insertions(+), 13 deletions(-)

diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py
index dcccb3207..e9100bff7 100644
--- a/dbgpt/app/knowledge/service.py
+++ b/dbgpt/app/knowledge/service.py
@@ -120,7 +120,11 @@ def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
             content=request.content,
             result="",
         )
-        return knowledge_document_dao.create_knowledge_document(document)
+        doc_id = knowledge_document_dao.create_knowledge_document(document)
+        if doc_id is None:
+            raise Exception(f"create document failed, {request.doc_name}")
+        return doc_id
+
 
     def get_knowledge_space(self, request: KnowledgeSpaceRequest):
         """get knowledge space
@@ -229,10 +233,20 @@ def batch_document_sync(
                 raise Exception(
                     f" doc:{doc.doc_name} status is {doc.status}, can not sync"
                 )
-            # space_context = self.get_space_context(space_name)
-            self._sync_knowledge_document(
-                space_name, doc, sync_request.chunk_parameters
-            )
+            chunk_parameters = sync_request.chunk_parameters
+            if chunk_parameters.chunk_strategy == "Automatic":
+                space_context = self.get_space_context(space_name)
+                chunk_parameters.chunk_size = (
+                    CFG.KNOWLEDGE_CHUNK_SIZE
+                    if space_context is None
+                    else int(space_context["embedding"]["chunk_size"])
+                )
+                chunk_parameters.chunk_overlap = (
+                    CFG.KNOWLEDGE_CHUNK_OVERLAP
+                    if space_context is None
+                    else int(space_context["embedding"]["chunk_overlap"])
+                )
+            self._sync_knowledge_document(space_name, doc, chunk_parameters)
             doc_ids.append(doc.id)
         return doc_ids
 
diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py
index ca611f5e5..1c064f102 100644
--- a/dbgpt/app/scene/chat_knowledge/v1/chat.py
+++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py
@@ -20,6 +20,7 @@
 from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.retriever.rewrite import QueryRewrite
+from dbgpt.util.prompt_util import PromptHelper
 from dbgpt.util.tracer import trace
 
 CFG = Config()
@@ -78,13 +79,15 @@ def __init__(self, chat_param: Dict):
             vector_store_config=config,
         )
         query_rewrite = None
+        self.worker_manager = CFG.SYSTEM_APP.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
         if CFG.KNOWLEDGE_SEARCH_REWRITE:
-            worker_manager = CFG.SYSTEM_APP.get_component(
-                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-            ).create()
-            llm_client = DefaultLLMClient(worker_manager=worker_manager)
             query_rewrite = QueryRewrite(
-                llm_client=llm_client, model_name=self.llm_model, language=CFG.LANGUAGE
+                llm_client=self.llm_client,
+                model_name=self.llm_model,
+                language=CFG.LANGUAGE,
             )
         self.embedding_retriever = EmbeddingRetriever(
             top_k=self.top_k,
@@ -149,9 +152,7 @@ async def generate_input_values(self) -> Dict:
                 if len(chucks) > 0:
                     self.chunks_with_score.append((chucks[0], chunk.score))
 
-        context = [doc.content for doc in candidates_with_scores]
-
-        context = context[: self.max_token]
+        context = "\n".join([doc.content for doc in candidates_with_scores])
         self.relations = list(
             set(
                 [
@@ -165,6 +166,19 @@ async def generate_input_values(self) -> Dict:
             "question": self.current_user_input,
             "relations": self.relations,
         }
+        prompt = self.prompt_template.format(**input_values)
+        model_metadata = await self.worker_manager.get_model_metadata(
+            {"model": self.llm_model}
+        )
+        try:
+            current_token_count = self.llm_client.count_token(prompt)
+        except Exception:
+            prompt_util = PromptHelper(context_window=model_metadata.context_length)
+            current_token_count = prompt_util.token_count(prompt)
+        if current_token_count > model_metadata.context_length:
+            input_values["context"] = input_values["context"][
+                : model_metadata.context_length - current_token_count
+            ]
         return input_values
 
     def parse_source_view(self, chunks_with_score: List):
diff --git a/dbgpt/util/prompt_util.py b/dbgpt/util/prompt_util.py
index 3bf368618..17d994d32 100644
--- a/dbgpt/util/prompt_util.py
+++ b/dbgpt/util/prompt_util.py
@@ -93,6 +93,11 @@ def __init__(
             separator=separator,
         )
 
+    def token_count(self, prompt_template: str) -> int:
+        """Get token count of prompt template."""
+        empty_prompt_txt = get_empty_prompt_txt(prompt_template)
+        return len(self._tokenizer(empty_prompt_txt))
+
     @classmethod
     def from_llm_metadata(
         cls,
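
The batch_document_sync change above resolves chunk parameters from the space context when the strategy is "Automatic": per-space embedding settings win, global defaults apply otherwise. A minimal standalone sketch of that fallback follows; the default values and the space_context shape are illustrative assumptions, not the exact DB-GPT configuration.

# Sketch of the "Automatic" chunk-strategy fallback (illustrative only).
from typing import Optional, Tuple

DEFAULT_CHUNK_SIZE = 500     # stand-in for CFG.KNOWLEDGE_CHUNK_SIZE
DEFAULT_CHUNK_OVERLAP = 50   # stand-in for CFG.KNOWLEDGE_CHUNK_OVERLAP


def resolve_chunk_params(space_context: Optional[dict]) -> Tuple[int, int]:
    """Prefer per-space embedding settings; fall back to global defaults."""
    if space_context is None:
        return DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
    embedding = space_context["embedding"]
    return int(embedding["chunk_size"]), int(embedding["chunk_overlap"])


# e.g. resolve_chunk_params({"embedding": {"chunk_size": "1000", "chunk_overlap": "100"}})
# returns (1000, 100), and resolve_chunk_params(None) returns (500, 50)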
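
The generate_input_values change renders the prompt, counts its tokens (falling back to PromptHelper when the LLM client cannot count them), and trims the retrieved context when the count exceeds the model's context window. Below is a rough sketch of that budget check; fit_context_to_window, count_tokens, and context_length are stand-ins for the real client and model metadata, not DB-GPT APIs.

# Sketch of the prompt-budget check (names are stand-ins, not DB-GPT APIs).
from typing import Callable


def fit_context_to_window(
    prompt: str,
    context: str,
    count_tokens: Callable[[str], int],
    context_length: int,
) -> str:
    """Trim the tail of the context when the rendered prompt overflows."""
    overflow = count_tokens(prompt) - context_length
    if overflow > 0:
        # Characters are a coarse proxy for tokens here, as in the patch.
        context = context[:-overflow]
    return context


# e.g. with a naive ~4-chars-per-token estimate as the tokenizer:
# fit_context_to_window(prompt, context, lambda s: len(s) // 4, 4096)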