feat(core): Support more chat flows #1180

Merged 3 commits on Feb 22, 2024
2 changes: 1 addition & 1 deletion dbgpt/_version.py
@@ -1 +1 @@
-version = "0.4.7"
+version = "0.5.0"
32 changes: 1 addition & 31 deletions dbgpt/app/openapi/api_v1/api_v1.py
@@ -366,11 +366,7 @@ async def chat_completions(
             context=flow_ctx,
         )
         return StreamingResponse(
-            flow_stream_generator(
-                flow_service.chat_flow(dialogue.select_param, flow_req),
-                dialogue.incremental,
-                dialogue.model_name,
-            ),
+            flow_service.chat_flow(dialogue.select_param, flow_req),
             headers=headers,
             media_type="text/event-stream",
         )
@@ -426,32 +422,6 @@ async def no_stream_generator(chat):
         yield f"data: {msg}\n\n"


-async def flow_stream_generator(func, incremental: bool, model_name: str):
-    stream_id = f"chatcmpl-{str(uuid.uuid1())}"
-    previous_response = ""
-    async for chunk in func:
-        if chunk:
-            msg = chunk.replace("\ufffd", "")
-            if incremental:
-                incremental_output = msg[len(previous_response) :]
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=DeltaMessage(role="assistant", content=incremental_output),
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=stream_id, choices=[choice_data], model=model_name
-                )
-                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
-            else:
-                # TODO generate an openai-compatible streaming responses
-                msg = msg.replace("\n", "\\n")
-                yield f"data:{msg}\n\n"
-            previous_response = msg
-            await asyncio.sleep(0.02)
-    if incremental:
-        yield "data: [DONE]\n\n"
-
-
 async def stream_generator(chat, incremental: bool, model_name: str):
     """Generate streaming responses

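Note: with flow_stream_generator removed, the API layer no longer re-frames the stream; it simply forwards whatever SSE lines flow_service.chat_flow yields. A minimal sketch of that wiring, with an illustrative fake_chat_flow generator in place of the real service call (the /demo/chat route and these names are hypothetical, not part of this PR):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_chat_flow(prompt: str):
    # Stands in for flow_service.chat_flow(...), which now yields
    # ready-to-send SSE frames itself.
    for token in ["Hello", ", ", "world"]:
        yield f"data:{token}\n\n"
    yield "data: [DONE]\n\n"

@app.get("/demo/chat")
async def demo_chat(prompt: str = "hi"):
    # The endpoint only wraps the generator; no per-chunk rewriting here.
    return StreamingResponse(fake_chat_flow(prompt), media_type="text/event-stream")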
30 changes: 25 additions & 5 deletions dbgpt/core/awel/flow/base.py
@@ -632,16 +632,36 @@ def get_runnable_parameters(
         runnable_parameters: Dict[str, Any] = {}
         if not self.parameters or not view_parameters:
             return runnable_parameters
-        if len(self.parameters) != len(view_parameters):
+        view_required_parameters = {
+            parameter.name: parameter
+            for parameter in view_parameters
+            if not parameter.optional
+        }
+        current_required_parameters = {
+            parameter.name: parameter
+            for parameter in self.parameters
+            if not parameter.optional
+        }
+        current_parameters = {
+            parameter.name: parameter for parameter in self.parameters
+        }
+        if len(view_required_parameters) < len(current_required_parameters):
+            # TODO, skip the optional parameters.
             raise FlowParameterMetadataException(
-                f"Parameters count not match. Expected {len(self.parameters)}, "
+                f"Parameters count not match(current key: {self.id}). "
+                f"Expected {len(self.parameters)}, "
                 f"but got {len(view_parameters)} from JSON metadata."
+                f"Required parameters: {current_required_parameters.keys()}, "
+                f"but got {view_required_parameters.keys()}."
             )
-        for i, parameter in enumerate(self.parameters):
-            view_param = view_parameters[i]
+        for view_param in view_parameters:
+            view_param_key = view_param.name
+            if view_param_key not in current_parameters:
+                raise FlowParameterMetadataException(
+                    f"Parameter {view_param_key} not found in the metadata."
+                )
             runnable_parameters.update(
-                parameter.to_runnable_parameter(
+                current_parameters[view_param_key].to_runnable_parameter(
                     view_param.get_typed_value(), resources, key_to_resource_instance
                 )
             )
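Note: the old count-equality check is replaced by name-keyed matching, so optional view parameters can be omitted without tripping the metadata exception. An isolated sketch of that strategy under hypothetical names (Param and match_params are not DB-GPT APIs):

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Param:
    name: str
    optional: bool = False
    value: Optional[str] = None

def match_params(declared: List[Param], incoming: List[Param]) -> Dict[str, Optional[str]]:
    # Key the declared parameters by name instead of relying on position.
    declared_by_name = {p.name: p for p in declared}
    required = [p for p in declared if not p.optional]
    incoming_required = [p for p in incoming if not p.optional]
    # Only the non-optional parameters must all be present.
    if len(incoming_required) < len(required):
        raise ValueError("Required parameters missing from JSON metadata")
    result: Dict[str, Optional[str]] = {}
    for p in incoming:
        if p.name not in declared_by_name:
            raise ValueError(f"Parameter {p.name} not found in the metadata")
        result[p.name] = p.value
    return result

# An optional parameter ("b") may now be omitted entirely:
print(match_params([Param("a"), Param("b", optional=True)], [Param("a", value="1")]))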
2 changes: 2 additions & 0 deletions dbgpt/core/awel/operators/base.py
@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     This class extends DAGNode by adding execution capabilities.
     """

+    streaming_operator: bool = False
+
     def __init__(
         self,
         task_id: Optional[str] = None,
4 changes: 4 additions & 0 deletions dbgpt/core/awel/operators/stream_operator.py
@@ -10,6 +10,8 @@
 class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
     """An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""

+    streaming_operator = True
+
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         call_data = curr_task_ctx.call_data
@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
     AsyncIterator[IN] to another AsyncIterator[OUT].
     """

+    streaming_operator = True
+
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
         curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
         output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[
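Note: combined with the streaming_operator: bool = False default on BaseOperator above, these overrides let callers test a plain class attribute instead of isinstance-checking every streaming operator type. A minimal illustration with hypothetical class names:

class BaseOperator:
    streaming_operator: bool = False

class StreamifyOperator(BaseOperator):
    streaming_operator = True

def pick_mode(op: BaseOperator, stream_requested: bool) -> str:
    # Mirrors the branch in _chat_with_dag_task: stream only when the request
    # asks for it and the terminal operator can actually stream.
    return "stream" if stream_requested and op.streaming_operator else "non-stream"

print(pick_mode(StreamifyOperator(), True))  # stream
print(pick_mode(BaseOperator(), True))       # non-stream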
6 changes: 3 additions & 3 deletions dbgpt/core/interface/operators/prompt_operator.py
@@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the messages."""
         if "system_message" not in values:
-            raise ValueError("No system message")
+            values["system_message"] = "You are a helpful AI Assistant."
         if "human_message" not in values:
-            raise ValueError("No human message")
+            values["human_message"] = "{user_input}"
         if "message_placeholder" not in values:
-            raise ValueError("No message placeholder")
+            values["message_placeholder"] = "chat_history"
         system_message = values.pop("system_message")
         human_message = values.pop("human_message")
         message_placeholder = values.pop("message_placeholder")
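Note: pre_fill now falls back to defaults instead of raising, so a flow can omit these keys. Equivalent logic in isolation, as a hypothetical stand-in (not the class method itself):

from typing import Any, Dict

def pre_fill(values: Dict[str, Any]) -> Dict[str, Any]:
    # setdefault only fills a key when it is absent, matching the diff's
    # "if key not in values" assignments.
    values.setdefault("system_message", "You are a helpful AI Assistant.")
    values.setdefault("human_message", "{user_input}")
    values.setdefault("message_placeholder", "chat_history")
    return values

print(pre_fill({}))  # all three defaults applied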
148 changes: 136 additions & 12 deletions dbgpt/serve/flow/service/service.py
@@ -1,6 +1,7 @@
 import json
 import logging
+import traceback
-from typing import List, Optional, cast
+from typing import Any, List, Optional, cast

 from fastapi import HTTPException

@@ -14,6 +15,7 @@
 from dbgpt.core.awel.dag.dag_manager import DAGManager
 from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
 from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
+from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.serve.core import BaseService
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
@@ -276,12 +278,39 @@ def get_list_by_page(
         """
         return self.dao.get_list_page(request, page, page_size)

-    async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
+    async def chat_flow(
+        self,
+        flow_uid: str,
+        request: CommonLLMHttpRequestBody,
+        incremental: bool = False,
+    ):
         """Chat with the AWEL flow.

         Args:
             flow_uid (str): The flow uid
             request (CommonLLMHttpRequestBody): The request
+            incremental (bool): Whether to return the result incrementally
         """
+        try:
+            async for output in self._call_chat_flow(flow_uid, request, incremental):
+                yield output
+        except HTTPException as e:
+            yield f"data:[SERVER_ERROR]{e.detail}\n\n"
+        except Exception as e:
+            yield f"data:[SERVER_ERROR]{str(e)}\n\n"
+
+    async def _call_chat_flow(
+        self,
+        flow_uid: str,
+        request: CommonLLMHttpRequestBody,
+        incremental: bool = False,
+    ):
+        """Chat with the AWEL flow.
+
+        Args:
+            flow_uid (str): The flow uid
+            request (CommonLLMHttpRequestBody): The request
+            incremental (bool): Whether to return the result incrementally
+        """
         flow = self.get({"uid": flow_uid})
         if not flow:
@@ -291,18 +320,18 @@ async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
             raise HTTPException(
                 status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
             )
-        if flow.flow_category != FlowCategory.CHAT_FLOW:
-            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         dag = self.dag_manager.dag_map[dag_id]
+        if (
+            flow.flow_category != FlowCategory.CHAT_FLOW
+            and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
+        ):
+            raise ValueError(f"Flow {flow_uid} is not a chat flow")
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
         end_node = cast(BaseOperator, leaf_nodes[0])
-        if request.stream:
-            async for output in await end_node.call_stream(request):
-                yield output
-        else:
-            yield await end_node.call(request)
+        async for output in _chat_with_dag_task(end_node, request, incremental):
+            yield output

     def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         """Parse the flow category
@@ -335,9 +364,104 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory:
         output = leaf_node.metadata.outputs[0]
         try:
             real_class = _get_type_cls(output.type_cls)
-            if common_http_trigger and (
-                real_class == str or real_class == CommonLLMHttpResponseBody
-            ):
+            if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
                 return FlowCategory.CHAT_FLOW
         except Exception:
             return FlowCategory.COMMON
+
+
+def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
+    try:
+        from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+    except ImportError:
+        OpenAIStreamingOutputOperator = None
+    if is_class:
+        return (
+            obj == str
+            or obj == CommonLLMHttpResponseBody
+            or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
+        )
+    else:
+        chat_types = (str, CommonLLMHttpResponseBody)
+        if OpenAIStreamingOutputOperator:
+            chat_types += (OpenAIStreamingOutputOperator,)
+        return isinstance(obj, chat_types)
+
+
+async def _chat_with_dag_task(
+    task: BaseOperator,
+    request: CommonLLMHttpRequestBody,
+    incremental: bool = False,
+):
+    """Chat with the DAG task.
+
+    Args:
+        task (BaseOperator): The task
+        request (CommonLLMHttpRequestBody): The request
+    """
+    if request.stream and task.streaming_operator:
+        try:
+            from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
+        except ImportError:
+            OpenAIStreamingOutputOperator = None
+        if incremental:
+            async for output in await task.call_stream(request):
+                yield output
+        else:
+            if OpenAIStreamingOutputOperator and isinstance(
+                task, OpenAIStreamingOutputOperator
+            ):
+                from fastchat.protocol.openai_api_protocol import (
+                    ChatCompletionResponseStreamChoice,
+                )
+
+                previous_text = ""
+                async for output in await task.call_stream(request):
+                    if not isinstance(output, str):
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+                    if output == "data: [DONE]\n\n":
+                        return
+                    json_data = "".join(output.split("data: ")[1:])
+                    dict_data = json.loads(json_data)
+                    if "choices" not in dict_data:
+                        error_msg = dict_data.get("text", "Unknown error")
+                        yield f"data:[SERVER_ERROR]{error_msg}\n\n"
+                        return
+                    choices = dict_data["choices"]
+                    if choices:
+                        choice = choices[0]
+                        delta_data = ChatCompletionResponseStreamChoice(**choice)
+                        if delta_data.delta.content:
+                            previous_text += delta_data.delta.content
+                        if previous_text:
+                            full_text = previous_text.replace("\n", "\\n")
+                            yield f"data:{full_text}\n\n"
+            else:
+                async for output in await task.call_stream(request):
+                    if isinstance(output, str):
+                        if output.strip():
+                            yield output
+                    else:
+                        yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
+                        return
+    else:
+        result = await task.call(request)
+        if result is None:
+            yield "data:[SERVER_ERROR]The result is None\n\n"
+        elif isinstance(result, str):
+            yield f"data:{result}\n\n"
+        elif isinstance(result, ModelOutput):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, CommonLLMHttpResponseBody):
+            if result.error_code != 0:
+                yield f"data:[SERVER_ERROR]{result.text}\n\n"
+            else:
+                yield f"data:{result.text}\n\n"
+        elif isinstance(result, dict):
+            yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
+        else:
+            yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"
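Note: chat_flow now yields ready-to-send frames: data:... for content, data:[SERVER_ERROR]... on failure, and data: [DONE] to close incremental streams. A hedged sketch of a consumer handling that framing, with fake_flow standing in for the real service generator:

import asyncio

async def consume(stream):
    async for frame in stream:
        if frame.startswith("data:[SERVER_ERROR]"):
            raise RuntimeError(frame[len("data:[SERVER_ERROR]"):].strip())
        if frame.strip() == "data: [DONE]":
            break  # end-of-stream sentinel
        if frame.startswith("data:"):
            print(frame[len("data:"):].strip())

async def fake_flow():
    # Stands in for FlowService.chat_flow(flow_uid, request); note that
    # non-incremental frames repeat the accumulated response text.
    yield "data:Hello\n\n"
    yield "data:Hello, world\n\n"
    yield "data: [DONE]\n\n"

asyncio.run(consume(fake_flow()))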
2 changes: 1 addition & 1 deletion dbgpt/util/dbgpts/repo.py
@@ -140,7 +140,7 @@ def update_repo(repo: str):
         logger.info(f"Repo '{repo}' is not a git repository.")
         return
     logger.info(f"Updating repo '{repo}'...")
-    subprocess.run(["git", "pull"], check=True)
+    subprocess.run(["git", "pull"], check=False)


 def install(
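Note: with check=True a failing git pull raises subprocess.CalledProcessError and aborts the surrounding update loop; check=False lets the caller log the exit code and continue with the remaining repos. For example:

import subprocess

result = subprocess.run(["git", "pull"], check=False)
if result.returncode != 0:
    # With check=True this would have raised CalledProcessError instead.
    print(f"git pull failed with exit code {result.returncode}")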
2 changes: 1 addition & 1 deletion docs/docs/upgrade/v0.5.0.md
@@ -1,4 +1,4 @@
-# Upgrade To v0.5.0(Draft)
+# Upgrade To v0.5.0

 ## Overview

2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
 IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
 # If you modify the version, please modify the version in the following files:
 # dbgpt/_version.py
-DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
+DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0")

 BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
 LLAMA_CPP_GPU_ACCELERATION = (