Skip to content

Commit

Permalink
fix(core): Fix bug of sharing data across DAGs (eosphoros-ai#1102)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored and Hopshine committed Sep 10, 2024
1 parent 5c1df87 commit 7ae762f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 57 deletions.
16 changes: 8 additions & 8 deletions dbgpt/core/awel/dag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,27 +446,25 @@ class DAGContext:

def __init__(
self,
node_to_outputs: Dict[str, TaskContext],
share_data: Dict[str, Any],
streaming_call: bool = False,
node_to_outputs: Optional[Dict[str, TaskContext]] = None,
node_name_to_ids: Optional[Dict[str, str]] = None,
) -> None:
"""Initialize a DAGContext.
Args:
node_to_outputs (Dict[str, TaskContext]): The task outputs of current DAG.
share_data (Dict[str, Any]): The share data of current DAG.
streaming_call (bool, optional): Whether the current DAG is streaming call.
Defaults to False.
node_to_outputs (Optional[Dict[str, TaskContext]], optional):
The task outputs of current DAG. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional):
The task name to task id mapping. Defaults to None.
node_name_to_ids (Optional[Dict[str, str]], optional): The node name to node
"""
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx: Optional[TaskContext] = None
self._share_data: Dict[str, Any] = {}
self._share_data: Dict[str, Any] = share_data
self._node_to_outputs: Dict[str, TaskContext] = node_to_outputs
self._node_name_to_ids: Dict[str, str] = node_name_to_ids

Expand Down Expand Up @@ -530,6 +528,7 @@ async def get_from_share_data(self, key: str) -> Any:
Returns:
Any: The share data, you can cast it to the real type
"""
logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
return self._share_data.get(key)

async def save_to_share_data(
Expand All @@ -545,6 +544,7 @@ async def save_to_share_data(
"""
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
self._share_data[key] = data

async def get_task_share_data(self, task_name: str, key: str) -> Any:
Expand Down
20 changes: 16 additions & 4 deletions dbgpt/core/awel/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This runner will run the workflow in the current process.
"""
import logging
from typing import Dict, List, Optional, Set, cast
from typing import Any, Dict, List, Optional, Set, cast

from dbgpt.component import SystemApp

Expand All @@ -20,6 +20,10 @@
class DefaultWorkflowRunner(WorkflowRunner):
"""The default workflow runner."""

def __init__(self):
    """Init the default workflow runner."""
    # Registry of in-flight executions, keyed by DAG id.
    # NOTE(review): per the surrounding diff, an entry is inserted when a
    # workflow run begins and deleted after the DAG ends — confirm no
    # concurrent runs of the same DAG id can clobber each other's entry.
    self._running_dag_ctx: Dict[str, DAGContext] = {}

async def execute_workflow(
self,
node: BaseOperator,
Expand All @@ -44,15 +48,22 @@ async def execute_workflow(
if not exist_dag_ctx:
# Create DAG context
node_outputs: Dict[str, TaskContext] = {}
share_data: Dict[str, Any] = {}
else:
# Share node output with exist dag context
node_outputs = exist_dag_ctx._node_to_outputs
share_data = exist_dag_ctx._share_data
dag_ctx = DAGContext(
streaming_call=streaming_call,
node_to_outputs=node_outputs,
share_data=share_data,
streaming_call=streaming_call,
node_name_to_ids=job_manager._node_name_to_ids,
)
logger.info(f"Begin run workflow from end operator, id: {node.node_id}")
if node.dag:
self._running_dag_ctx[node.dag.dag_id] = dag_ctx
logger.info(
f"Begin run workflow from end operator, id: {node.node_id}, runner: {self}"
)
logger.debug(f"Node id {node.node_id}, call_data: {call_data}")
skip_node_ids: Set[str] = set()
system_app: Optional[SystemApp] = DAGVar.get_current_system_app()
Expand All @@ -64,7 +75,8 @@ async def execute_workflow(
if not streaming_call and node.dag:
# streaming call not work for dag end
await node.dag._after_dag_end()

if node.dag:
del self._running_dag_ctx[node.dag.dag_id]
return dag_ctx

async def _execute_node(
Expand Down
119 changes: 74 additions & 45 deletions dbgpt/core/awel/trigger/http_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,21 @@
from starlette.requests import Request

RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
StreamingPredictFunc = Callable[[Union[Request, BaseModel, str, None]], bool]

logger = logging.getLogger(__name__)


class AWELHttpError(RuntimeError):
    """AWEL Http Error."""

    def __init__(self, msg: str, code: Optional[str] = None):
        """Create an AWEL HTTP error.

        Args:
            msg (str): Human-readable error message; also passed to the
                RuntimeError base so ``str(err)`` returns it.
            code (Optional[str], optional): Machine-readable error code.
                Defaults to None.
        """
        # Keep the message and code as attributes for programmatic access.
        self.msg = msg
        self.code = code
        super().__init__(msg)


class HttpTrigger(Trigger):
"""Http trigger for AWEL.
Expand Down Expand Up @@ -65,29 +75,74 @@ def mount_to_router(self, router: "APIRouter") -> None:
Args:
router (APIRouter): The router to mount the trigger.
"""
from fastapi import Depends
from inspect import Parameter, Signature
from typing import get_type_hints

from starlette.requests import Request

methods = [self._methods] if isinstance(self._methods, str) else self._methods
is_query_method = (
all(method in ["GET", "DELETE"] for method in methods) if methods else True
)

async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]):
streaming_response = self._streaming_response
if self._streaming_predict_func:
streaming_response = self._streaming_predict_func(body)
dag = self.dag
if not dag:
raise AWELHttpError("DAG is not set")
return await _trigger_dag(
body,
dag,
streaming_response,
self._response_headers,
self._response_media_type,
)

def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)

async def route_function(body=Depends(_request_body_dependency)):
streaming_response = self._streaming_response
if self._streaming_predict_func:
streaming_response = self._streaming_predict_func(body)
return await _trigger_dag(
body,
self.dag,
streaming_response,
self._response_headers,
self._response_media_type,
)

route_function.__name__ = name
return route_function
async def route_function_request(request: Request):
return await _trigger_dag_func(request)

async def route_function_none():
return await _trigger_dag_func(None)

route_function_request.__name__ = name
route_function_none.__name__ = name

if not req_body_cls:
return route_function_none
if req_body_cls == Request:
return route_function_request

if is_query_method:
if req_body_cls == str:
raise AWELHttpError(f"Query methods {methods} not support str type")

async def route_function_get(**kwargs):
body = req_body_cls(**kwargs)
return await _trigger_dag_func(body)

parameters = [
Parameter(
name=field_name,
kind=Parameter.KEYWORD_ONLY,
default=Parameter.empty,
annotation=field.outer_type_,
)
for field_name, field in req_body_cls.__fields__.items()
]
route_function_get.__signature__ = Signature(parameters) # type: ignore
route_function_get.__annotations__ = get_type_hints(req_body_cls)
route_function_get.__name__ = name
return route_function_get
else:

async def route_function(body: req_body_cls): # type: ignore
return await _trigger_dag_func(body)

route_function.__name__ = name
return route_function

function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
request_model = (
Expand All @@ -111,32 +166,6 @@ async def route_function(body=Depends(_request_body_dependency)):
)(dynamic_route_function)


async def _parse_request_body(
    request: "Request", request_body_cls: Optional["RequestBody"]
):
    """Convert an incoming HTTP request into the configured body type.

    Args:
        request (Request): The incoming starlette request.
        request_body_cls (Optional[RequestBody]): Target type — ``Request``
            itself, ``str``, or a pydantic ``BaseModel`` subclass.

    Returns:
        The parsed body: ``None`` when no body class is configured, the raw
        request when ``Request`` is requested, the decoded text or model for
        POST, the query-built model for GET. Methods other than POST/GET
        fall through and return ``None``.

    Raises:
        ValueError: If the configured body class cannot be parsed for the
            request's method.
    """
    from starlette.requests import Request

    if not request_body_cls:
        return None
    if request_body_cls == Request:
        return request

    method = request.method
    if method == "POST":
        if request_body_cls == str:
            # Decode the raw body bytes as UTF-8 text.
            raw_body = await request.body()
            return raw_body.decode("utf-8")
        if issubclass(request_body_cls, BaseModel):
            payload = await request.json()
            return request_body_cls(**payload)
        raise ValueError(f"Invalid request body cls: {request_body_cls}")
    if method == "GET":
        # GET carries its data in the query string, not the body.
        if issubclass(request_body_cls, BaseModel):
            return request_body_cls(**request.query_params)
        raise ValueError(f"Invalid request body cls: {request_body_cls}")


async def _trigger_dag(
body: Any,
dag: DAG,
Expand Down

0 comments on commit 7ae762f

Please sign in to comment.