From 3b6da64ebbbb1321c48bf990f7a486cc2c7a2d23 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 22 Feb 2024 12:19:04 +0800 Subject: [PATCH] feat(core): Support more chat flows (#1180) --- dbgpt/_version.py | 2 +- dbgpt/app/openapi/api_v1/api_v1.py | 32 +--- dbgpt/core/awel/flow/base.py | 30 +++- dbgpt/core/awel/operators/base.py | 2 + dbgpt/core/awel/operators/stream_operator.py | 4 + .../interface/operators/prompt_operator.py | 6 +- dbgpt/serve/flow/service/service.py | 148 ++++++++++++++++-- dbgpt/util/dbgpts/repo.py | 2 +- docs/docs/upgrade/v0.5.0.md | 2 +- setup.py | 2 +- 10 files changed, 175 insertions(+), 55 deletions(-) diff --git a/dbgpt/_version.py b/dbgpt/_version.py index e7cc0113f..819a3005a 100644 --- a/dbgpt/_version.py +++ b/dbgpt/_version.py @@ -1 +1 @@ -version = "0.4.7" +version = "0.5.0" diff --git a/dbgpt/app/openapi/api_v1/api_v1.py b/dbgpt/app/openapi/api_v1/api_v1.py index 67c545812..0762feb17 100644 --- a/dbgpt/app/openapi/api_v1/api_v1.py +++ b/dbgpt/app/openapi/api_v1/api_v1.py @@ -366,11 +366,7 @@ async def chat_completions( context=flow_ctx, ) return StreamingResponse( - flow_stream_generator( - flow_service.chat_flow(dialogue.select_param, flow_req), - dialogue.incremental, - dialogue.model_name, - ), + flow_service.chat_flow(dialogue.select_param, flow_req), headers=headers, media_type="text/event-stream", ) @@ -426,32 +422,6 @@ async def no_stream_generator(chat): yield f"data: {msg}\n\n" -async def flow_stream_generator(func, incremental: bool, model_name: str): - stream_id = f"chatcmpl-{str(uuid.uuid1())}" - previous_response = "" - async for chunk in func: - if chunk: - msg = chunk.replace("\ufffd", "") - if incremental: - incremental_output = msg[len(previous_response) :] - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(role="assistant", content=incremental_output), - ) - chunk = ChatCompletionStreamResponse( - id=stream_id, choices=[choice_data], model=model_name - ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" - else: - # TODO generate an openai-compatible streaming responses - msg = msg.replace("\n", "\\n") - yield f"data:{msg}\n\n" - previous_response = msg - await asyncio.sleep(0.02) - if incremental: - yield "data: [DONE]\n\n" - - async def stream_generator(chat, incremental: bool, model_name: str): """Generate streaming responses diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 835b12632..e573b0567 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -632,16 +632,36 @@ def get_runnable_parameters( runnable_parameters: Dict[str, Any] = {} if not self.parameters or not view_parameters: return runnable_parameters - if len(self.parameters) != len(view_parameters): + view_required_parameters = { + parameter.name: parameter + for parameter in view_parameters + if not parameter.optional + } + current_required_parameters = { + parameter.name: parameter + for parameter in self.parameters + if not parameter.optional + } + current_parameters = { + parameter.name: parameter for parameter in self.parameters + } + if len(view_required_parameters) < len(current_required_parameters): # TODO, skip the optional parameters. raise FlowParameterMetadataException( - f"Parameters count not match. Expected {len(self.parameters)}, " + f"Parameters count not match(current key: {self.id}). " + f"Expected {len(self.parameters)}, " f"but got {len(view_parameters)} from JSON metadata." + f"Required parameters: {current_required_parameters.keys()}, " + f"but got {view_required_parameters.keys()}." ) - for i, parameter in enumerate(self.parameters): - view_param = view_parameters[i] + for view_param in view_parameters: + view_param_key = view_param.name + if view_param_key not in current_parameters: + raise FlowParameterMetadataException( + f"Parameter {view_param_key} not found in the metadata." + ) runnable_parameters.update( - parameter.to_runnable_parameter( + current_parameters[view_param_key].to_runnable_parameter( view_param.get_typed_value(), resources, key_to_resource_instance ) ) diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index d6ebc85d9..495e7bc2e 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta): This class extends DAGNode by adding execution capabilities. """ + streaming_operator: bool = False + def __init__( self, task_id: Optional[str] = None, diff --git a/dbgpt/core/awel/operators/stream_operator.py b/dbgpt/core/awel/operators/stream_operator.py index 8893a51f9..79c0273d7 100644 --- a/dbgpt/core/awel/operators/stream_operator.py +++ b/dbgpt/core/awel/operators/stream_operator.py @@ -10,6 +10,8 @@ class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]): """An abstract operator that converts a value of IN to an AsyncIterator[OUT].""" + streaming_operator = True + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context call_data = curr_task_ctx.call_data @@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]): AsyncIterator[IN] to another AsyncIterator[OUT]. """ + streaming_operator = True + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[ diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py index 0018101e8..4a965b628 100644 --- a/dbgpt/core/interface/operators/prompt_operator.py +++ b/dbgpt/core/interface/operators/prompt_operator.py @@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate): def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the messages.""" if "system_message" not in values: - raise ValueError("No system message") + values["system_message"] = "You are a helpful AI Assistant." if "human_message" not in values: - raise ValueError("No human message") + values["human_message"] = "{user_input}" if "message_placeholder" not in values: - raise ValueError("No message placeholder") + values["message_placeholder"] = "chat_history" system_message = values.pop("system_message") human_message = values.pop("human_message") message_placeholder = values.pop("message_placeholder") diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 9ef80118c..0ed5e462c 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -1,6 +1,7 @@ +import json import logging import traceback -from typing import List, Optional, cast +from typing import Any, List, Optional, cast from fastapi import HTTPException @@ -14,6 +15,7 @@ from dbgpt.core.awel.dag.dag_manager import DAGManager from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger +from dbgpt.core.interface.llm import ModelOutput from dbgpt.serve.core import BaseService from dbgpt.storage.metadata import BaseDao from dbgpt.storage.metadata._base_dao import QUERY_SPEC @@ -276,12 +278,39 @@ def get_list_by_page( """ return self.dao.get_list_page(request, page, page_size) - async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody): + async def chat_flow( + self, + flow_uid: str, + request: CommonLLMHttpRequestBody, + incremental: bool = False, + ): """Chat with the AWEL flow. Args: flow_uid (str): The flow uid request (CommonLLMHttpRequestBody): The request + incremental (bool): Whether to return the result incrementally + """ + try: + async for output in self._call_chat_flow(flow_uid, request, incremental): + yield output + except HTTPException as e: + yield f"data:[SERVER_ERROR]{e.detail}\n\n" + except Exception as e: + yield f"data:[SERVER_ERROR]{str(e)}\n\n" + + async def _call_chat_flow( + self, + flow_uid: str, + request: CommonLLMHttpRequestBody, + incremental: bool = False, + ): + """Chat with the AWEL flow. + + Args: + flow_uid (str): The flow uid + request (CommonLLMHttpRequestBody): The request + incremental (bool): Whether to return the result incrementally """ flow = self.get({"uid": flow_uid}) if not flow: @@ -291,18 +320,18 @@ async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody): raise HTTPException( status_code=404, detail=f"Flow {flow_uid}'s dag id not found" ) - if flow.flow_category != FlowCategory.CHAT_FLOW: - raise ValueError(f"Flow {flow_uid} is not a chat flow") dag = self.dag_manager.dag_map[dag_id] + if ( + flow.flow_category != FlowCategory.CHAT_FLOW + and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW + ): + raise ValueError(f"Flow {flow_uid} is not a chat flow") leaf_nodes = dag.leaf_nodes if len(leaf_nodes) != 1: raise ValueError("Chat Flow just support one leaf node in dag") end_node = cast(BaseOperator, leaf_nodes[0]) - if request.stream: - async for output in await end_node.call_stream(request): - yield output - else: - yield await end_node.call(request) + async for output in _chat_with_dag_task(end_node, request, incremental): + yield output def _parse_flow_category(self, dag: DAG) -> FlowCategory: """Parse the flow category @@ -335,9 +364,104 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory: output = leaf_node.metadata.outputs[0] try: real_class = _get_type_cls(output.type_cls) - if common_http_trigger and ( - real_class == str or real_class == CommonLLMHttpResponseBody - ): + if common_http_trigger and _is_chat_flow_type(real_class, is_class=True): return FlowCategory.CHAT_FLOW except Exception: return FlowCategory.COMMON + + +def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool: + try: + from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator + except ImportError: + OpenAIStreamingOutputOperator = None + if is_class: + return ( + obj == str + or obj == CommonLLMHttpResponseBody + or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator) + ) + else: + chat_types = (str, CommonLLMHttpResponseBody) + if OpenAIStreamingOutputOperator: + chat_types += (OpenAIStreamingOutputOperator,) + return isinstance(obj, chat_types) + + +async def _chat_with_dag_task( + task: BaseOperator, + request: CommonLLMHttpRequestBody, + incremental: bool = False, +): + """Chat with the DAG task. + + Args: + task (BaseOperator): The task + request (CommonLLMHttpRequestBody): The request + """ + if request.stream and task.streaming_operator: + try: + from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator + except ImportError: + OpenAIStreamingOutputOperator = None + if incremental: + async for output in await task.call_stream(request): + yield output + else: + if OpenAIStreamingOutputOperator and isinstance( + task, OpenAIStreamingOutputOperator + ): + from fastchat.protocol.openai_api_protocol import ( + ChatCompletionResponseStreamChoice, + ) + + previous_text = "" + async for output in await task.call_stream(request): + if not isinstance(output, str): + yield "data:[SERVER_ERROR]The output is not a stream format\n\n" + return + if output == "data: [DONE]\n\n": + return + json_data = "".join(output.split("data: ")[1:]) + dict_data = json.loads(json_data) + if "choices" not in dict_data: + error_msg = dict_data.get("text", "Unknown error") + yield f"data:[SERVER_ERROR]{error_msg}\n\n" + return + choices = dict_data["choices"] + if choices: + choice = choices[0] + delta_data = ChatCompletionResponseStreamChoice(**choice) + if delta_data.delta.content: + previous_text += delta_data.delta.content + if previous_text: + full_text = previous_text.replace("\n", "\\n") + yield f"data:{full_text}\n\n" + else: + async for output in await task.call_stream(request): + if isinstance(output, str): + if output.strip(): + yield output + else: + yield "data:[SERVER_ERROR]The output is not a stream format\n\n" + return + else: + result = await task.call(request) + if result is None: + yield "data:[SERVER_ERROR]The result is None\n\n" + elif isinstance(result, str): + yield f"data:{result}\n\n" + elif isinstance(result, ModelOutput): + if result.error_code != 0: + yield f"data:[SERVER_ERROR]{result.text}\n\n" + else: + yield f"data:{result.text}\n\n" + elif isinstance(result, CommonLLMHttpResponseBody): + if result.error_code != 0: + yield f"data:[SERVER_ERROR]{result.text}\n\n" + else: + yield f"data:{result.text}\n\n" + elif isinstance(result, dict): + yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n" + else: + yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n" diff --git a/dbgpt/util/dbgpts/repo.py b/dbgpt/util/dbgpts/repo.py index 7ab471821..e86c836a7 100644 --- a/dbgpt/util/dbgpts/repo.py +++ b/dbgpt/util/dbgpts/repo.py @@ -140,7 +140,7 @@ def update_repo(repo: str): logger.info(f"Repo '{repo}' is not a git repository.") return logger.info(f"Updating repo '{repo}'...") - subprocess.run(["git", "pull"], check=True) + subprocess.run(["git", "pull"], check=False) def install( diff --git a/docs/docs/upgrade/v0.5.0.md b/docs/docs/upgrade/v0.5.0.md index 97d8d8e93..e8b24af20 100644 --- a/docs/docs/upgrade/v0.5.0.md +++ b/docs/docs/upgrade/v0.5.0.md @@ -1,4 +1,4 @@ -# Upgrade To v0.5.0(Draft) +# Upgrade To v0.5.0 ## Overview diff --git a/setup.py b/setup.py index e17eb48c9..a8074296f 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true" # If you modify the version, please modify the version in the following files: # dbgpt/_version.py -DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7") +DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0") BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true" LLAMA_CPP_GPU_ACCELERATION = (