1
1
"""Http trigger for AWEL."""
2
- from __future__ import annotations
3
-
4
2
import logging
5
3
from enum import Enum
6
4
from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Type , Union , cast
7
5
8
- from starlette .requests import Request
9
-
10
6
from dbgpt ._private .pydantic import BaseModel
11
7
12
8
from ..dag .base import DAG
15
11
16
12
if TYPE_CHECKING :
17
13
from fastapi import APIRouter
14
+ from starlette .requests import Request
18
15
19
- RequestBody = Union [Type [Request ], Type [BaseModel ], Type [str ]]
20
- StreamingPredictFunc = Callable [[Union [Request , BaseModel ]], bool ]
16
+ RequestBody = Union [Type [Request ], Type [BaseModel ], Type [str ]]
17
+ StreamingPredictFunc = Callable [[Union [Request , BaseModel ]], bool ]
21
18
22
19
logger = logging .getLogger (__name__ )
23
20
@@ -32,9 +29,9 @@ def __init__(
32
29
self ,
33
30
endpoint : str ,
34
31
methods : Optional [Union [str , List [str ]]] = "GET" ,
35
- request_body : Optional [RequestBody ] = None ,
32
+ request_body : Optional [" RequestBody" ] = None ,
36
33
streaming_response : bool = False ,
37
- streaming_predict_func : Optional [StreamingPredictFunc ] = None ,
34
+ streaming_predict_func : Optional [" StreamingPredictFunc" ] = None ,
38
35
response_model : Optional [Type ] = None ,
39
36
response_headers : Optional [Dict [str , str ]] = None ,
40
37
response_media_type : Optional [str ] = None ,
@@ -69,6 +66,7 @@ def mount_to_router(self, router: "APIRouter") -> None:
69
66
router (APIRouter): The router to mount the trigger.
70
67
"""
71
68
from fastapi import Depends
69
+ from starlette .requests import Request
72
70
73
71
methods = [self ._methods ] if isinstance (self ._methods , str ) else self ._methods
74
72
@@ -114,8 +112,10 @@ async def route_function(body=Depends(_request_body_dependency)):
114
112
115
113
116
114
async def _parse_request_body (
117
- request : Request , request_body_cls : Optional [RequestBody ]
115
+ request : " Request" , request_body_cls : Optional [" RequestBody" ]
118
116
):
117
+ from starlette .requests import Request
118
+
119
119
if not request_body_cls :
120
120
return None
121
121
if request_body_cls == Request :
@@ -152,7 +152,7 @@ async def _trigger_dag(
152
152
raise ValueError ("HttpTrigger just support one leaf node in dag" )
153
153
end_node = cast (BaseOperator , leaf_nodes [0 ])
154
154
if not streaming_response :
155
- return await end_node .call (call_data = { "data" : body } )
155
+ return await end_node .call (call_data = body )
156
156
else :
157
157
headers = response_headers
158
158
media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
163
163
"Connection" : "keep-alive" ,
164
164
"Transfer-Encoding" : "chunked" ,
165
165
}
166
- generator = await end_node .call_stream (call_data = { "data" : body } )
166
+ generator = await end_node .call_stream (call_data = body )
167
167
background_tasks = BackgroundTasks ()
168
168
background_tasks .add_task (dag ._after_dag_end )
169
169
return StreamingResponse (
0 commit comments