1
+ import json
1
2
import logging
2
3
import traceback
3
- from typing import List , Optional , cast
4
+ from typing import Any , List , Optional , cast
4
5
5
6
from fastapi import HTTPException
6
7
14
15
from dbgpt .core .awel .dag .dag_manager import DAGManager
15
16
from dbgpt .core .awel .flow .flow_factory import FlowCategory , FlowFactory
16
17
from dbgpt .core .awel .trigger .http_trigger import CommonLLMHttpTrigger
18
+ from dbgpt .core .interface .llm import ModelOutput
17
19
from dbgpt .serve .core import BaseService
18
20
from dbgpt .storage .metadata import BaseDao
19
21
from dbgpt .storage .metadata ._base_dao import QUERY_SPEC
@@ -276,12 +278,39 @@ def get_list_by_page(
276
278
"""
277
279
return self .dao .get_list_page (request , page , page_size )
278
280
279
async def chat_flow(
    self,
    flow_uid: str,
    request: "CommonLLMHttpRequestBody",
    incremental: bool = False,
):
    """Chat with the AWEL flow.

    Safe wrapper around ``_call_chat_flow``: any exception raised while
    resolving or running the flow is converted into an SSE-style
    ``data:[SERVER_ERROR]...`` payload, so the client stream always
    terminates cleanly instead of aborting mid-response.

    Args:
        flow_uid (str): The flow uid
        request (CommonLLMHttpRequestBody): The request
        incremental (bool): Whether to return the result incrementally
    """
    try:
        async for output in self._call_chat_flow(flow_uid, request, incremental):
            yield output
    except HTTPException as e:
        # Expected API error (e.g. unknown flow uid): forward the detail only.
        logging.getLogger(__name__).warning(
            "Chat flow %s failed: %s", flow_uid, e.detail
        )
        yield f"data:[SERVER_ERROR]{e.detail}\n\n"
    except Exception as e:
        # Unexpected failure: keep the full traceback server-side; the client
        # only receives the message text.
        logging.getLogger(__name__).exception("Chat flow %s failed", flow_uid)
        yield f"data:[SERVER_ERROR]{str(e)}\n\n"
302
+ async def _call_chat_flow (
303
+ self ,
304
+ flow_uid : str ,
305
+ request : CommonLLMHttpRequestBody ,
306
+ incremental : bool = False ,
307
+ ):
308
+ """Chat with the AWEL flow.
309
+
310
+ Args:
311
+ flow_uid (str): The flow uid
312
+ request (CommonLLMHttpRequestBody): The request
313
+ incremental (bool): Whether to return the result incrementally
285
314
"""
286
315
flow = self .get ({"uid" : flow_uid })
287
316
if not flow :
@@ -291,18 +320,18 @@ async def chat_flow(self, flow_uid: str, request: CommonLLMHttpRequestBody):
291
320
raise HTTPException (
292
321
status_code = 404 , detail = f"Flow { flow_uid } 's dag id not found"
293
322
)
294
- if flow .flow_category != FlowCategory .CHAT_FLOW :
295
- raise ValueError (f"Flow { flow_uid } is not a chat flow" )
296
323
dag = self .dag_manager .dag_map [dag_id ]
324
+ if (
325
+ flow .flow_category != FlowCategory .CHAT_FLOW
326
+ and self ._parse_flow_category (dag ) != FlowCategory .CHAT_FLOW
327
+ ):
328
+ raise ValueError (f"Flow { flow_uid } is not a chat flow" )
297
329
leaf_nodes = dag .leaf_nodes
298
330
if len (leaf_nodes ) != 1 :
299
331
raise ValueError ("Chat Flow just support one leaf node in dag" )
300
332
end_node = cast (BaseOperator , leaf_nodes [0 ])
301
- if request .stream :
302
- async for output in await end_node .call_stream (request ):
303
- yield output
304
- else :
305
- yield await end_node .call (request )
333
+ async for output in _chat_with_dag_task (end_node , request , incremental ):
334
+ yield output
306
335
307
336
def _parse_flow_category (self , dag : DAG ) -> FlowCategory :
308
337
"""Parse the flow category
@@ -335,9 +364,104 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory:
335
364
output = leaf_node .metadata .outputs [0 ]
336
365
try :
337
366
real_class = _get_type_cls (output .type_cls )
338
- if common_http_trigger and (
339
- real_class == str or real_class == CommonLLMHttpResponseBody
340
- ):
367
+ if common_http_trigger and _is_chat_flow_type (real_class , is_class = True ):
341
368
return FlowCategory .CHAT_FLOW
342
369
except Exception :
343
370
return FlowCategory .COMMON
371
+
372
+
373
+ def _is_chat_flow_type (obj : Any , is_class : bool = False ) -> bool :
374
+ try :
375
+ from dbgpt .model .utils .chatgpt_utils import OpenAIStreamingOutputOperator
376
+ except ImportError :
377
+ OpenAIStreamingOutputOperator = None
378
+ if is_class :
379
+ return (
380
+ obj == str
381
+ or obj == CommonLLMHttpResponseBody
382
+ or (OpenAIStreamingOutputOperator and obj == OpenAIStreamingOutputOperator )
383
+ )
384
+ else :
385
+ chat_types = (str , CommonLLMHttpResponseBody )
386
+ if OpenAIStreamingOutputOperator :
387
+ chat_types += (OpenAIStreamingOutputOperator ,)
388
+ return isinstance (obj , chat_types )
389
+
390
+
391
+ async def _chat_with_dag_task (
392
+ task : BaseOperator ,
393
+ request : CommonLLMHttpRequestBody ,
394
+ incremental : bool = False ,
395
+ ):
396
+ """Chat with the DAG task.
397
+
398
+ Args:
399
+ task (BaseOperator): The task
400
+ request (CommonLLMHttpRequestBody): The request
401
+ """
402
+ if request .stream and task .streaming_operator :
403
+ try :
404
+ from dbgpt .model .utils .chatgpt_utils import OpenAIStreamingOutputOperator
405
+ except ImportError :
406
+ OpenAIStreamingOutputOperator = None
407
+ if incremental :
408
+ async for output in await task .call_stream (request ):
409
+ yield output
410
+ else :
411
+ if OpenAIStreamingOutputOperator and isinstance (
412
+ task , OpenAIStreamingOutputOperator
413
+ ):
414
+ from fastchat .protocol .openai_api_protocol import (
415
+ ChatCompletionResponseStreamChoice ,
416
+ )
417
+
418
+ previous_text = ""
419
+ async for output in await task .call_stream (request ):
420
+ if not isinstance (output , str ):
421
+ yield "data:[SERVER_ERROR]The output is not a stream format\n \n "
422
+ return
423
+ if output == "data: [DONE]\n \n " :
424
+ return
425
+ json_data = "" .join (output .split ("data: " )[1 :])
426
+ dict_data = json .loads (json_data )
427
+ if "choices" not in dict_data :
428
+ error_msg = dict_data .get ("text" , "Unknown error" )
429
+ yield f"data:[SERVER_ERROR]{ error_msg } \n \n "
430
+ return
431
+ choices = dict_data ["choices" ]
432
+ if choices :
433
+ choice = choices [0 ]
434
+ delta_data = ChatCompletionResponseStreamChoice (** choice )
435
+ if delta_data .delta .content :
436
+ previous_text += delta_data .delta .content
437
+ if previous_text :
438
+ full_text = previous_text .replace ("\n " , "\\ n" )
439
+ yield f"data:{ full_text } \n \n "
440
+ else :
441
+ async for output in await task .call_stream (request ):
442
+ if isinstance (output , str ):
443
+ if output .strip ():
444
+ yield output
445
+ else :
446
+ yield "data:[SERVER_ERROR]The output is not a stream format\n \n "
447
+ return
448
+ else :
449
+ result = await task .call (request )
450
+ if result is None :
451
+ yield "data:[SERVER_ERROR]The result is None\n \n "
452
+ elif isinstance (result , str ):
453
+ yield f"data:{ result } \n \n "
454
+ elif isinstance (result , ModelOutput ):
455
+ if result .error_code != 0 :
456
+ yield f"data:[SERVER_ERROR]{ result .text } \n \n "
457
+ else :
458
+ yield f"data:{ result .text } \n \n "
459
+ elif isinstance (result , CommonLLMHttpResponseBody ):
460
+ if result .error_code != 0 :
461
+ yield f"data:[SERVER_ERROR]{ result .text } \n \n "
462
+ else :
463
+ yield f"data:{ result .text } \n \n "
464
+ elif isinstance (result , dict ):
465
+ yield f"data:{ json .dumps (result , ensure_ascii = False )} \n \n "
466
+ else :
467
+ yield f"data:[SERVER_ERROR]The result is not a valid format({ type (result )} )\n \n "
0 commit comments