feat: ChatKnowledge add prompt token count
Aries-ckt committed Dec 29, 2023
1 parent bae7def commit 3bd625e
Showing 3 changed files with 46 additions and 13 deletions.
24 changes: 19 additions & 5 deletions dbgpt/app/knowledge/service.py
@@ -120,7 +120,11 @@ def create_knowledge_document(self, space, request: KnowledgeDocumentRequest):
             content=request.content,
             result="",
         )
-        return knowledge_document_dao.create_knowledge_document(document)
+        doc_id = knowledge_document_dao.create_knowledge_document(document)
+        if doc_id is None:
+            raise Exception(f"create document failed, {request.doc_name}")
+        return doc_id
+

     def get_knowledge_space(self, request: KnowledgeSpaceRequest):
         """get knowledge space
@@ -229,10 +233,20 @@ def batch_document_sync(
                 raise Exception(
                     f" doc:{doc.doc_name} status is {doc.status}, can not sync"
                 )
-            # space_context = self.get_space_context(space_name)
-            self._sync_knowledge_document(
-                space_name, doc, sync_request.chunk_parameters
-            )
+            chunk_parameters = sync_request.chunk_parameters
+            if "Automatic" == chunk_parameters.chunk_strategy:
+                space_context = self.get_space_context(space_name)
+                chunk_parameters.chunk_size = (
+                    CFG.KNOWLEDGE_CHUNK_SIZE
+                    if space_context is None
+                    else int(space_context["embedding"]["chunk_size"])
+                )
+                chunk_parameters.chunk_overlap = (
+                    CFG.KNOWLEDGE_CHUNK_OVERLAP
+                    if space_context is None
+                    else int(space_context["embedding"]["chunk_overlap"])
+                )
+            self._sync_knowledge_document(space_name, doc, chunk_parameters)
             doc_ids.append(doc.id)
         return doc_ids

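For reference, the "Automatic" branch above resolves chunk sizing from the knowledge space's stored embedding settings and only falls back to the global CFG.KNOWLEDGE_CHUNK_SIZE / CFG.KNOWLEDGE_CHUNK_OVERLAP defaults when no space context exists. A minimal standalone sketch of that resolution (the resolve_chunk_parameters helper and the default values here are illustrative assumptions, not part of the commit):

```python
from typing import Optional, Tuple

# Illustrative fallbacks; in DB-GPT these come from CFG.KNOWLEDGE_CHUNK_SIZE / CFG.KNOWLEDGE_CHUNK_OVERLAP.
DEFAULT_CHUNK_SIZE = 512
DEFAULT_CHUNK_OVERLAP = 50


def resolve_chunk_parameters(space_context: Optional[dict]) -> Tuple[int, int]:
    """Prefer the space's own embedding settings; otherwise use the global defaults."""
    if space_context is None:
        return DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP
    embedding = space_context["embedding"]
    return int(embedding["chunk_size"]), int(embedding["chunk_overlap"])


# A space configured with its own settings wins over the defaults.
print(resolve_chunk_parameters({"embedding": {"chunk_size": "1000", "chunk_overlap": "100"}}))  # (1000, 100)
print(resolve_chunk_parameters(None))  # (512, 50)
```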
30 changes: 22 additions & 8 deletions dbgpt/app/scene/chat_knowledge/v1/chat.py
@@ -20,6 +20,7 @@
 from dbgpt.model import DefaultLLMClient
 from dbgpt.model.cluster import WorkerManagerFactory
 from dbgpt.rag.retriever.rewrite import QueryRewrite
+from dbgpt.util.prompt_util import PromptHelper
 from dbgpt.util.tracer import trace

 CFG = Config()
@@ -78,13 +79,15 @@ def __init__(self, chat_param: Dict):
             vector_store_config=config,
         )
         query_rewrite = None
+        self.worker_manager = CFG.SYSTEM_APP.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
         if CFG.KNOWLEDGE_SEARCH_REWRITE:
-            worker_manager = CFG.SYSTEM_APP.get_component(
-                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-            ).create()
-            llm_client = DefaultLLMClient(worker_manager=worker_manager)
             query_rewrite = QueryRewrite(
-                llm_client=llm_client, model_name=self.llm_model, language=CFG.LANGUAGE
+                llm_client=self.llm_client,
+                model_name=self.llm_model,
+                language=CFG.LANGUAGE,
             )
         self.embedding_retriever = EmbeddingRetriever(
             top_k=self.top_k,
@@ -149,9 +152,7 @@ async def generate_input_values(self) -> Dict:
             if len(chucks) > 0:
                 self.chunks_with_score.append((chucks[0], chunk.score))

-        context = [doc.content for doc in candidates_with_scores]
-
-        context = context[: self.max_token]
+        context = "\n".join([doc.content for doc in candidates_with_scores])
         self.relations = list(
             set(
                 [
@@ -165,6 +166,19 @@
             "question": self.current_user_input,
             "relations": self.relations,
         }
+        prompt = self.prompt_template.format(**input_values)
+        model_metadata = await self.worker_manager.get_model_metadata(
+            {"model": self.llm_model}
+        )
+        try:
+            current_token_count = self.llm_client.count_token(prompt)
+        except Exception:
+            prompt_util = PromptHelper(context_window=model_metadata.context_length)
+            current_token_count = prompt_util.token_count(prompt)
+        if current_token_count > model_metadata.context_length:
+            input_values["context"] = input_values["context"][
+                : current_token_count - model_metadata.context_length
+            ]
         return input_values

     def parse_source_view(self, chunks_with_score: List):
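The core of this commit is the new block at the end of generate_input_values: render the prompt, count its tokens (preferring the LLM client's tokenizer and falling back to PromptHelper), and shrink the retrieved context when the count exceeds the model's context window. A rough standalone sketch of that idea with a pluggable token counter (the helper name, the naive whitespace counter, and the trim-from-the-end policy are illustrative assumptions, not the exact slice the commit uses):

```python
from typing import Callable, Dict


def whitespace_token_count(text: str) -> int:
    """Crude stand-in for a real tokenizer, used only for this example."""
    return len(text.split())


def fit_context_to_window(
    input_values: Dict[str, str],
    prompt_template: str,
    context_length: int,
    count_tokens: Callable[[str], int] = whitespace_token_count,
) -> Dict[str, str]:
    """Shrink input_values['context'] when the rendered prompt exceeds the window."""
    prompt = prompt_template.format(**input_values)
    overflow = count_tokens(prompt) - context_length
    if overflow > 0:
        # Drop roughly the overflow from the tail of the context (characters as a crude token proxy).
        input_values["context"] = input_values["context"][:-overflow]
    return input_values


# Usage: a tiny window forces the context to be trimmed.
values = {"context": "alpha beta gamma delta", "question": "What is alpha?", "relations": "[]"}
template = "context: {context}\nquestion: {question}\nrelations: {relations}"
print(fit_context_to_window(values, template, context_length=6)["context"])
```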
5 changes: 5 additions & 0 deletions dbgpt/util/prompt_util.py
@@ -93,6 +93,11 @@ def __init__(
             separator=separator,
         )

+    def token_count(self, prompt_template: str) -> int:
+        """Get token count of prompt template."""
+        empty_prompt_txt = get_empty_prompt_txt(prompt_template)
+        return len(self._tokenizer(empty_prompt_txt))
+
     @classmethod
     def from_llm_metadata(
         cls,
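The new PromptHelper.token_count runs the template's fixed text (after get_empty_prompt_txt strips the format variables) through the tokenizer the helper was constructed with and returns the token count. A tiny sketch of the same pattern with an injectable tokenizer callable (the class name and the whitespace tokenizer are stand-ins, and the variable-stripping step is skipped here):

```python
from typing import Callable, List, Optional


class TinyPromptHelper:
    """Stand-in for PromptHelper, reduced to the token-counting concern."""

    def __init__(self, tokenizer: Optional[Callable[[str], List[str]]] = None):
        # Fall back to a naive whitespace tokenizer when none is supplied.
        self._tokenizer = tokenizer or (lambda text: text.split())

    def token_count(self, prompt_template: str) -> int:
        """Count the tokens in a prompt template with the configured tokenizer."""
        return len(self._tokenizer(prompt_template))


helper = TinyPromptHelper()
print(helper.token_count("You are a helpful assistant. Answer using the context below."))  # 10
```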
