Skip to content

Commit ab5e1c7

Browse files
authored
feat(core): Support more chat flows (#1180)
1 parent 16fa68d commit ab5e1c7

File tree

10 files changed

+175
-55
lines changed

10 files changed

+175
-55
lines changed

dbgpt/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "0.4.7"
1+
version = "0.5.0"

dbgpt/app/openapi/api_v1/api_v1.py

+1-31
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,7 @@ async def chat_completions(
366366
context=flow_ctx,
367367
)
368368
return StreamingResponse(
369-
flow_stream_generator(
370-
flow_service.chat_flow(dialogue.select_param, flow_req),
371-
dialogue.incremental,
372-
dialogue.model_name,
373-
),
369+
flow_service.chat_flow(dialogue.select_param, flow_req),
374370
headers=headers,
375371
media_type="text/event-stream",
376372
)
@@ -426,32 +422,6 @@ async def no_stream_generator(chat):
426422
yield f"data: {msg}\n\n"
427423

428424

429-
async def flow_stream_generator(func, incremental: bool, model_name: str):
430-
stream_id = f"chatcmpl-{str(uuid.uuid1())}"
431-
previous_response = ""
432-
async for chunk in func:
433-
if chunk:
434-
msg = chunk.replace("\ufffd", "")
435-
if incremental:
436-
incremental_output = msg[len(previous_response) :]
437-
choice_data = ChatCompletionResponseStreamChoice(
438-
index=0,
439-
delta=DeltaMessage(role="assistant", content=incremental_output),
440-
)
441-
chunk = ChatCompletionStreamResponse(
442-
id=stream_id, choices=[choice_data], model=model_name
443-
)
444-
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
445-
else:
446-
# TODO generate an openai-compatible streaming responses
447-
msg = msg.replace("\n", "\\n")
448-
yield f"data:{msg}\n\n"
449-
previous_response = msg
450-
await asyncio.sleep(0.02)
451-
if incremental:
452-
yield "data: [DONE]\n\n"
453-
454-
455425
async def stream_generator(chat, incremental: bool, model_name: str):
456426
"""Generate streaming responses
457427

dbgpt/core/awel/flow/base.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -632,16 +632,36 @@ def get_runnable_parameters(
632632
runnable_parameters: Dict[str, Any] = {}
633633
if not self.parameters or not view_parameters:
634634
return runnable_parameters
635-
if len(self.parameters) != len(view_parameters):
635+
view_required_parameters = {
636+
parameter.name: parameter
637+
for parameter in view_parameters
638+
if not parameter.optional
639+
}
640+
current_required_parameters = {
641+
parameter.name: parameter
642+
for parameter in self.parameters
643+
if not parameter.optional
644+
}
645+
current_parameters = {
646+
parameter.name: parameter for parameter in self.parameters
647+
}
648+
if len(view_required_parameters) < len(current_required_parameters):
636649
# TODO, skip the optional parameters.
637650
raise FlowParameterMetadataException(
638-
f"Parameters count not match. Expected {len(self.parameters)}, "
651+
f"Parameters count not match(current key: {self.id}). "
652+
f"Expected {len(self.parameters)}, "
639653
f"but got {len(view_parameters)} from JSON metadata."
654+
f"Required parameters: {current_required_parameters.keys()}, "
655+
f"but got {view_required_parameters.keys()}."
640656
)
641-
for i, parameter in enumerate(self.parameters):
642-
view_param = view_parameters[i]
657+
for view_param in view_parameters:
658+
view_param_key = view_param.name
659+
if view_param_key not in current_parameters:
660+
raise FlowParameterMetadataException(
661+
f"Parameter {view_param_key} not found in the metadata."
662+
)
643663
runnable_parameters.update(
644-
parameter.to_runnable_parameter(
664+
current_parameters[view_param_key].to_runnable_parameter(
645665
view_param.get_typed_value(), resources, key_to_resource_instance
646666
)
647667
)

dbgpt/core/awel/operators/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
122122
This class extends DAGNode by adding execution capabilities.
123123
"""
124124

125+
streaming_operator: bool = False
126+
125127
def __init__(
126128
self,
127129
task_id: Optional[str] = None,

dbgpt/core/awel/operators/stream_operator.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
1111
"""An abstract operator that converts a value of IN to an AsyncIterator[OUT]."""
1212

13+
streaming_operator = True
14+
1315
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
1416
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
1517
call_data = curr_task_ctx.call_data
@@ -83,6 +85,8 @@ class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
8385
AsyncIterator[IN] to another AsyncIterator[OUT].
8486
"""
8587

88+
streaming_operator = True
89+
8690
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
8791
curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
8892
output: TaskOutput[OUT] = await curr_task_ctx.task_input.parent_outputs[

dbgpt/core/interface/operators/prompt_operator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
7474
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
7575
"""Pre fill the messages."""
7676
if "system_message" not in values:
77-
raise ValueError("No system message")
77+
values["system_message"] = "You are a helpful AI Assistant."
7878
if "human_message" not in values:
79-
raise ValueError("No human message")
79+
values["human_message"] = "{user_input}"
8080
if "message_placeholder" not in values:
81-
raise ValueError("No message placeholder")
81+
values["message_placeholder"] = "chat_history"
8282
system_message = values.pop("system_message")
8383
human_message = values.pop("human_message")
8484
message_placeholder = values.pop("message_placeholder")

dbgpt/serve/flow/service/service.py

+136-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import json
12
import logging
23
import traceback
3-
from typing import List, Optional, cast
4+
from typing import Any, List, Optional, cast
45

56
from fastapi import HTTPException
67

@@ -14,6 +15,7 @@
1415
from dbgpt.core.awel.dag.dag_manager import DAGManager
1516
from dbgpt.core.awel.flow.flow_factory import FlowCategory, FlowFactory
1617
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger
18+
from dbgpt.core.interface.llm import ModelOutput
1719
from dbgpt.serve.core import BaseService
1820
from dbgpt.storage.metadata import BaseDao
1921
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
@@ -276,12 +278,39 @@ def get_list_by_page(
276278
"""
277279
return self.dao.get_list_page(request, page, page_size)
278280

279-
async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
281+
async def chat_flow(
282+
self,
283+
flow_uid: str,
284+
request: CommonLLMHttpRequestBody,
285+
incremental: bool = False,
286+
):
280287
"""Chat with the AWEL flow.
281288
282289
Args:
283290
flow_uid (str): The flow uid
284291
request (CommonLLMHttpRequestBody): The request
292+
incremental (bool): Whether to return the result incrementally
293+
"""
294+
try:
295+
async for output in self._call_chat_flow(flow_uid, request, incremental):
296+
yield output
297+
except HTTPException as e:
298+
yield f"data:[SERVER_ERROR]{e.detail}\n\n"
299+
except Exception as e:
300+
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
301+
302+
async def _call_chat_flow(
303+
self,
304+
flow_uid: str,
305+
request: CommonLLMHttpRequestBody,
306+
incremental: bool = False,
307+
):
308+
"""Chat with the AWEL flow.
309+
310+
Args:
311+
flow_uid (str): The flow uid
312+
request (CommonLLMHttpRequestBody): The request
313+
incremental (bool): Whether to return the result incrementally
285314
"""
286315
flow = self.get({"uid": flow_uid})
287316
if not flow:
@@ -291,18 +320,18 @@ async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
291320
raise HTTPException(
292321
status_code=404, detail=f"Flow {flow_uid}'s dag id not found"
293322
)
294-
if flow.flow_category != FlowCategory.CHAT_FLOW:
295-
raise ValueError(f"Flow {flow_uid} is not a chat flow")
296323
dag = self.dag_manager.dag_map[dag_id]
324+
if (
325+
flow.flow_category != FlowCategory.CHAT_FLOW
326+
and self._parse_flow_category(dag) != FlowCategory.CHAT_FLOW
327+
):
328+
raise ValueError(f"Flow {flow_uid} is not a chat flow")
297329
leaf_nodes = dag.leaf_nodes
298330
if len(leaf_nodes) != 1:
299331
raise ValueError("Chat Flow just support one leaf node in dag")
300332
end_node = cast(BaseOperator, leaf_nodes[0])
301-
if request.stream:
302-
async for output in await end_node.call_stream(request):
303-
yield output
304-
else:
305-
yield await end_node.call(request)
333+
async for output in _chat_with_dag_task(end_node, request, incremental):
334+
yield output
306335

307336
def _parse_flow_category(self, dag: DAG) -> FlowCategory:
308337
"""Parse the flow category
@@ -335,9 +364,104 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory:
335364
output = leaf_node.metadata.outputs[0]
336365
try:
337366
real_class = _get_type_cls(output.type_cls)
338-
if common_http_trigger and (
339-
real_class == str or real_class == CommonLLMHttpResponseBody
340-
):
367+
if common_http_trigger and _is_chat_flow_type(real_class, is_class=True):
341368
return FlowCategory.CHAT_FLOW
342369
except Exception:
343370
return FlowCategory.COMMON
371+
372+
373+
def _is_chat_flow_type(obj: Any, is_class: bool = False) -> bool:
374+
try:
375+
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
376+
except ImportError:
377+
OpenAIStreamingOutputOperator = None
378+
if is_class:
379+
return (
380+
obj == str
381+
or obj == CommonLLMHttpResponseBody
382+
or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator)
383+
)
384+
else:
385+
chat_types = (str, CommonLLMHttpResponseBody)
386+
if OpenAIStreamingOutputOperator:
387+
chat_types += (OpenAIStreamingOutputOperator,)
388+
return isinstance(obj, chat_types)
389+
390+
391+
async def _chat_with_dag_task(
392+
task: BaseOperator,
393+
request: CommonLLMHttpRequestBody,
394+
incremental: bool = False,
395+
):
396+
"""Chat with the DAG task.
397+
398+
Args:
399+
task (BaseOperator): The task
400+
request (CommonLLMHttpRequestBody): The request
401+
"""
402+
if request.stream and task.streaming_operator:
403+
try:
404+
from dbgpt.model.utils.chatgpt_utils import OpenAIStreamingOutputOperator
405+
except ImportError:
406+
OpenAIStreamingOutputOperator = None
407+
if incremental:
408+
async for output in await task.call_stream(request):
409+
yield output
410+
else:
411+
if OpenAIStreamingOutputOperator and isinstance(
412+
task, OpenAIStreamingOutputOperator
413+
):
414+
from fastchat.protocol.openai_api_protocol import (
415+
ChatCompletionResponseStreamChoice,
416+
)
417+
418+
previous_text = ""
419+
async for output in await task.call_stream(request):
420+
if not isinstance(output, str):
421+
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
422+
return
423+
if output == "data: [DONE]\n\n":
424+
return
425+
json_data = "".join(output.split("data: ")[1:])
426+
dict_data = json.loads(json_data)
427+
if "choices" not in dict_data:
428+
error_msg = dict_data.get("text", "Unknown error")
429+
yield f"data:[SERVER_ERROR]{error_msg}\n\n"
430+
return
431+
choices = dict_data["choices"]
432+
if choices:
433+
choice = choices[0]
434+
delta_data = ChatCompletionResponseStreamChoice(**choice)
435+
if delta_data.delta.content:
436+
previous_text += delta_data.delta.content
437+
if previous_text:
438+
full_text = previous_text.replace("\n", "\\n")
439+
yield f"data:{full_text}\n\n"
440+
else:
441+
async for output in await task.call_stream(request):
442+
if isinstance(output, str):
443+
if output.strip():
444+
yield output
445+
else:
446+
yield "data:[SERVER_ERROR]The output is not a stream format\n\n"
447+
return
448+
else:
449+
result = await task.call(request)
450+
if result is None:
451+
yield "data:[SERVER_ERROR]The result is None\n\n"
452+
elif isinstance(result, str):
453+
yield f"data:{result}\n\n"
454+
elif isinstance(result, ModelOutput):
455+
if result.error_code != 0:
456+
yield f"data:[SERVER_ERROR]{result.text}\n\n"
457+
else:
458+
yield f"data:{result.text}\n\n"
459+
elif isinstance(result, CommonLLMHttpResponseBody):
460+
if result.error_code != 0:
461+
yield f"data:[SERVER_ERROR]{result.text}\n\n"
462+
else:
463+
yield f"data:{result.text}\n\n"
464+
elif isinstance(result, dict):
465+
yield f"data:{json.dumps(result, ensure_ascii=False)}\n\n"
466+
else:
467+
yield f"data:[SERVER_ERROR]The result is not a valid format({type(result)})\n\n"

dbgpt/util/dbgpts/repo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def update_repo(repo: str):
140140
logger.info(f"Repo '{repo}' is not a git repository.")
141141
return
142142
logger.info(f"Updating repo '{repo}'...")
143-
subprocess.run(["git", "pull"], check=True)
143+
subprocess.run(["git", "pull"], check=False)
144144

145145

146146
def install(

docs/docs/upgrade/v0.5.0.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Upgrade To v0.5.0(Draft)
1+
# Upgrade To v0.5.0
22

33
## Overview
44

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
1919
# If you modify the version, please modify the version in the following files:
2020
# dbgpt/_version.py
21-
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.4.7")
21+
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.0")
2222

2323
BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
2424
LLAMA_CPP_GPU_ACCELERATION = (

0 commit comments

Comments (0)