From 820d0c01ec8665642912e992ded99898851a950d Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Thu, 18 Jan 2024 13:11:37 +0800 Subject: [PATCH] fix: rag summary supports zhipu message_converter --- dbgpt/app/knowledge/service.py | 4 +++- dbgpt/rag/extractor/summary.py | 4 ++-- examples/awel/simple_rag_summary_example.py | 13 ++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py index eba47df37..a10869f66 100644 --- a/dbgpt/app/knowledge/service.py +++ b/dbgpt/app/knowledge/service.py @@ -400,7 +400,9 @@ async def document_summary(self, request: DocumentSummaryRequest): assembler = SummaryAssembler( knowledge=knowledge, model_name=request.model_name, - llm_client=DefaultLLMClient(worker_manager=worker_manager), + llm_client=DefaultLLMClient( + worker_manager=worker_manager, auto_convert_message=True + ), language=CFG.LANGUAGE, chunk_parameters=chunk_parameters, ) diff --git a/dbgpt/rag/extractor/summary.py b/dbgpt/rag/extractor/summary.py index c3f84d232..0e1cb1b77 100644 --- a/dbgpt/rag/extractor/summary.py +++ b/dbgpt/rag/extractor/summary.py @@ -18,7 +18,7 @@ the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters. """ -REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。""" +REFINE_SUMMARY_TEMPLATE_ZH = """我们已经提供了一个到某一点的现有总结:{context}\n 请根据你之前推理的内容进行总结,总结回答的时候最好按照1.2.3.进行. 注意:请用<中文>来进行总结。""" REFINE_SUMMARY_TEMPLATE_EN = """ We have provided an existing summary up to a certain point: {context}, We have the opportunity to refine the existing summary (only if needed) with some more context below. 
@@ -144,7 +144,7 @@ async def _llm_run_tasks( from dbgpt.core import ModelMessage prompt = prompt_template.format(context=chunk_text) - messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)] + messages = [ModelMessage(role=ModelMessageRoleType.HUMAN, content=prompt)] request = ModelRequest(model=self._model_name, messages=messages) tasks.append(self._llm_client.generate(request)) summary_results = await run_async_tasks( diff --git a/examples/awel/simple_rag_summary_example.py b/examples/awel/simple_rag_summary_example.py index 5447af032..ca1c145fe 100644 --- a/examples/awel/simple_rag_summary_example.py +++ b/examples/awel/simple_rag_summary_example.py @@ -21,11 +21,9 @@ .. code-block:: shell - DBGPT_SERVER="http://127.0.0.1:5000" - FILE_PATH="{your_file_path}" curl -X POST http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/summary \ -H "Content-Type: application/json" -d '{ - "file_path": $FILE_PATH + "url": "https://docs.dbgpt.site/docs/awel" }' """ from typing import Dict @@ -33,12 +31,13 @@ from dbgpt._private.pydantic import BaseModel, Field from dbgpt.core.awel import DAG, HttpTrigger, MapOperator from dbgpt.model import OpenAILLMClient +from dbgpt.rag.knowledge.base import KnowledgeType from dbgpt.rag.operator.knowledge import KnowledgeOperator from dbgpt.rag.operator.summary import SummaryAssemblerOperator class TriggerReqBody(BaseModel): - file_path: str = Field(..., description="file_path") + url: str = Field(..., description="url") class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]): @@ -47,7 +46,7 @@ def __init__(self, **kwargs): async def map(self, input_value: TriggerReqBody) -> Dict: params = { - "file_path": input_value.file_path, + "url": input_value.url, } print(f"Receive input value: {input_value}") return params @@ -58,9 +57,9 @@ async def map(self, input_value: TriggerReqBody) -> Dict: "/examples/rag/summary", methods="POST", request_body=TriggerReqBody ) request_handle_task = RequestHandleOperator() - 
path_operator = MapOperator(lambda request: request["file_path"]) + path_operator = MapOperator(lambda request: request["url"]) # build knowledge operator - knowledge_operator = KnowledgeOperator() + knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL) # build summary assembler operator summary_operator = SummaryAssemblerOperator( llm_client=OpenAILLMClient(), language="en"