Skip to content

Commit

Permalink
fix(core): Fix AWEL branch bug (#1640)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Jun 18, 2024
1 parent 49b56b4 commit ace169a
Show file tree
Hide file tree
Showing 32 changed files with 866 additions and 477 deletions.
6 changes: 2 additions & 4 deletions dbgpt/app/scene/operators/app_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from dbgpt.core.awel import (
DAG,
BaseOperator,
BranchJoinOperator,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
)
Expand Down Expand Up @@ -195,9 +195,7 @@ def build_cached_chat_operator(
cache_task_name=cache_task_name,
)
# Create a join node to merge the outputs of the model and cache nodes, keeping only the first non-empty output
join_task = JoinOperator(
combine_function=lambda model_out, cache_out: cache_out or model_out
)
join_task = BranchJoinOperator()

# Define the workflow structure using the >> operator
input_task >> cache_check_branch_task
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .operators.base import BaseOperator, WorkflowRunner
from .operators.common_operator import (
BranchFunc,
BranchJoinOperator,
BranchOperator,
BranchTaskType,
InputOperator,
Expand Down Expand Up @@ -78,6 +79,7 @@
"ReduceStreamOperator",
"TriggerOperator",
"MapOperator",
"BranchJoinOperator",
"BranchOperator",
"InputOperator",
"BranchFunc",
Expand Down
31 changes: 25 additions & 6 deletions dbgpt/core/awel/flow/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The mixin of DAGs."""

import abc
import dataclasses
import inspect
Expand Down Expand Up @@ -337,6 +338,9 @@ class Parameter(TypeMetadata, Serializable):
value: Optional[Any] = Field(
None, description="The value of the parameter(Saved in the dag file)"
)
alias: Optional[List[str]] = Field(
None, description="The alias of the parameter(Compatible with old version)"
)

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -398,6 +402,7 @@ def build_from(
description: Optional[str] = None,
options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None,
resource_type: ResourceType = ResourceType.INSTANCE,
alias: Optional[List[str]] = None,
):
"""Build the parameter from the type."""
type_name = type.__qualname__
Expand All @@ -419,6 +424,7 @@ def build_from(
placeholder=placeholder,
description=description or label,
options=options,
alias=alias,
)

@classmethod
Expand Down Expand Up @@ -452,7 +458,7 @@ def build_from_ui(cls, data: Dict) -> "Parameter":

def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
dict_value = model_to_dict(self, exclude={"options"})
dict_value = model_to_dict(self, exclude={"options", "alias"})
if not self.options:
dict_value["options"] = None
elif isinstance(self.options, BaseDynamicOptions):
Expand Down Expand Up @@ -677,9 +683,18 @@ def get_runnable_parameters(
for parameter in self.parameters
if not parameter.optional
}
current_parameters = {
parameter.name: parameter for parameter in self.parameters
}
current_parameters = {}
current_aliases_parameters = {}
for parameter in self.parameters:
current_parameters[parameter.name] = parameter
if parameter.alias:
for alias in parameter.alias:
if alias in current_aliases_parameters:
raise FlowMetadataException(
f"Alias {alias} already exists in the metadata."
)
current_aliases_parameters[alias] = parameter

if len(view_required_parameters) < len(current_required_parameters):
# TODO, skip the optional parameters.
raise FlowParameterMetadataException(
Expand All @@ -691,12 +706,16 @@ def get_runnable_parameters(
)
for view_param in view_parameters:
view_param_key = view_param.name
if view_param_key not in current_parameters:
if view_param_key in current_parameters:
current_parameter = current_parameters[view_param_key]
elif view_param_key in current_aliases_parameters:
current_parameter = current_aliases_parameters[view_param_key]
else:
raise FlowParameterMetadataException(
f"Parameter {view_param_key} not in the metadata."
)
runnable_parameters.update(
current_parameters[view_param_key].to_runnable_parameter(
current_parameter.to_runnable_parameter(
view_param.get_typed_value(), resources, key_to_resource_instance
)
)
Expand Down
44 changes: 39 additions & 5 deletions dbgpt/core/awel/flow/compat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
"""Compatibility mapping for flow classes."""

from dataclasses import dataclass
from typing import Dict, Optional

_COMPAT_FLOW_MAPPING: Dict[str, str] = {}

@dataclass
class _RegisterItem:
    """Compatibility record mapping an old flow class location to a new one.

    Class keys are rendered as ``"<module>.<name>"``.
    """

    # Dotted module path the class was registered under previously.
    old_module: str
    # Dotted module path the class is registered under now.
    new_module: str
    # Class name inside ``old_module``.
    old_name: str
    # Class name inside ``new_module``; when missing, ``old_name`` is reused.
    new_name: Optional[str] = None
    # Optional version marker stored with the entry; not interpreted here.
    after: Optional[str] = None

    def old_cls_key(self) -> str:
        """Return the key of the old class as ``"<old_module>.<old_name>"``."""
        return f"{self.old_module}.{self.old_name}"

    def new_cls_key(self) -> str:
        """Return the key of the new class as ``"<new_module>.<name>"``.

        ``name`` is ``new_name`` when it is set and ``old_name`` otherwise, so
        an item built without an explicit new name never yields
        ``"<new_module>.None"``.
        """
        name = self.new_name or self.old_name
        return f"{self.new_module}.{name}"


_COMPAT_FLOW_MAPPING: Dict[str, _RegisterItem] = {}


_OLD_AGENT_RESOURCE_MODULE_1 = "dbgpt.serve.agent.team.layout.agent_operator_resource"
Expand All @@ -11,17 +32,24 @@


def _register(
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
old_module: str,
new_module: str,
old_name: str,
new_name: Optional[str] = None,
after_version: Optional[str] = None,
):
if not new_name:
new_name = old_name
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
item = _RegisterItem(old_module, new_module, old_name, new_name, after_version)
_COMPAT_FLOW_MAPPING[item.old_cls_key()] = item


def get_new_class_name(old_class_name: str) -> Optional[str]:
"""Get the new class name for the old class name."""
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
return new_cls_name
if old_class_name not in _COMPAT_FLOW_MAPPING:
return None
item = _COMPAT_FLOW_MAPPING[old_class_name]
return item.new_cls_key()


_register(
Expand Down Expand Up @@ -54,3 +82,9 @@ def get_new_class_name(old_class_name: str) -> Optional[str]:
_register(
_OLD_AGENT_RESOURCE_MODULE_2, _NEW_AGENT_RESOURCE_MODULE, "AWELAgent", "AWELAgent"
)
_register(
"dbgpt.storage.vector_store.connector",
"dbgpt.serve.rag.connector",
"VectorStoreConnector",
after_version="v0.5.8",
)
8 changes: 0 additions & 8 deletions dbgpt/core/awel/flow/flow_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,6 @@ def build(self, flow_panel: FlowPanel) -> DAG:
downstream = key_to_downstream.get(operator_key, [])
if not downstream:
raise ValueError("Branch operator should have downstream.")
if len(downstream) != len(view_metadata.parameters):
raise ValueError(
"Branch operator should have the same number of downstream as "
"parameters."
)
for i, param in enumerate(view_metadata.parameters):
downstream_key, _, _ = downstream[i]
param.value = key_to_operator_nodes[downstream_key].data.name

try:
runnable_params = metadata.get_runnable_parameters(
Expand Down
6 changes: 6 additions & 0 deletions dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
task_name: Optional[str] = None,
dag: Optional[DAG] = None,
runner: Optional[WorkflowRunner] = None,
can_skip_in_branch: bool = True,
**kwargs,
) -> None:
"""Create a BaseOperator with an optional workflow runner.
Expand All @@ -157,6 +158,7 @@ def __init__(

self._runner: WorkflowRunner = runner
self._dag_ctx: Optional[DAGContext] = None
self._can_skip_in_branch = can_skip_in_branch

@property
def current_dag_context(self) -> DAGContext:
Expand Down Expand Up @@ -321,6 +323,10 @@ def current_event_loop_task_id(self) -> int:
"""Get the current event loop task id."""
return id(asyncio.current_task())

def can_skip_in_branch(self) -> bool:
    """Report whether this operator may be skipped in a branch.

    Returns:
        bool: The value stored in ``self._can_skip_in_branch``.
    """
    skippable = self._can_skip_in_branch
    return skippable


def initialize_runner(runner: WorkflowRunner):
"""Initialize the default runner."""
Expand Down
40 changes: 38 additions & 2 deletions dbgpt/core/awel/operators/common_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ReduceFunc,
TaskContext,
TaskOutput,
is_empty_data,
)
from .base import BaseOperator

Expand All @@ -28,13 +29,16 @@ class JoinOperator(BaseOperator, Generic[OUT]):
This node type is useful for combining the outputs of upstream nodes.
"""

def __init__(self, combine_function: JoinFunc, **kwargs):
def __init__(
self, combine_function: JoinFunc, can_skip_in_branch: bool = True, **kwargs
):
"""Create a JoinDAGNode with a combine function.
Args:
combine_function: A function that defines how to combine inputs.
can_skip_in_branch(bool): Whether the node can be skipped in a branch.
"""
super().__init__(**kwargs)
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
if not callable(combine_function):
raise ValueError("combine_function must be callable")
self.combine_function = combine_function
Expand All @@ -57,6 +61,12 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx.set_task_output(join_output)
return join_output

async def _return_first_non_empty(self, *inputs):
    """Return the first input that ``is_empty_data`` does not report as empty.

    Args:
        *inputs: Candidate values, checked in the order given.

    Raises:
        ValueError: If every input is empty (or no inputs were given).
    """
    # A unique sentinel distinguishes "nothing found" from a legitimate
    # falsy or None value that is_empty_data did not flag as empty.
    missing = object()
    first = next(
        (candidate for candidate in inputs if not is_empty_data(candidate)),
        missing,
    )
    if first is missing:
        raise ValueError("All inputs are empty")
    return first


class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
"""Operator that reduces inputs using a custom reduce function."""
Expand Down Expand Up @@ -287,6 +297,32 @@ async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
raise NotImplementedError


class BranchJoinOperator(JoinOperator, Generic[OUT]):
    """Join operator intended to merge the outputs of branch paths.

    Unlike a plain ``JoinOperator`` it needs no combine function: when none is
    supplied, ``self._return_first_non_empty`` is used. It is also created
    with ``can_skip_in_branch=False`` unless the caller says otherwise.
    """

    def __init__(
        self,
        combine_function: Optional[JoinFunc] = None,
        can_skip_in_branch: bool = False,
        **kwargs,
    ):
        """Create a BranchJoinOperator.

        Args:
            combine_function: A function that defines how to combine inputs.
                When omitted, ``self._return_first_non_empty`` is used.
            can_skip_in_branch(bool): Whether the node can be skipped in a
                branch (default False).
            **kwargs: Forwarded unchanged to ``JoinOperator.__init__``.
        """
        if combine_function:
            join_func = combine_function
        else:
            join_func = self._return_first_non_empty
        super().__init__(
            combine_function=join_func,
            can_skip_in_branch=can_skip_in_branch,
            **kwargs,
        )


class InputOperator(BaseOperator, Generic[OUT]):
"""Operator node that reads data from an input source."""

Expand Down
9 changes: 5 additions & 4 deletions dbgpt/core/awel/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This runner will run the workflow in the current process.
"""

import asyncio
import logging
import traceback
Expand All @@ -11,7 +12,7 @@

from ..dag.base import DAGContext, DAGVar
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
from ..operators.common_operator import BranchOperator, JoinOperator
from ..operators.common_operator import BranchOperator
from ..task.base import SKIP_DATA, TaskContext, TaskState
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
from .job_manager import JobManager
Expand Down Expand Up @@ -184,14 +185,14 @@ def _skip_current_downstream_by_node_name(
return
for child in branch_node.downstream:
child = cast(BaseOperator, child)
if child.node_name in skip_nodes:
if child.node_name in skip_nodes or child.node_id in skip_node_ids:
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
_skip_downstream_by_id(child, skip_node_ids)


def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
if isinstance(node, JoinOperator):
# Not skip join node
if not node.can_skip_in_branch():
# The current node cannot be skipped: stop here so that neither it nor its downstream is marked as skipped
return
skip_node_ids.add(node.node_id)
for child in node.downstream:
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/core/awel/tests/test_run_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def join_func(o1, o2) -> int:
even_node = MapOperator(
lambda x: 888, task_id="even_node", task_name="even_node_name"
)
join_node = JoinOperator(join_func)
join_node = JoinOperator(join_func, can_skip_in_branch=False)
branch_node = BranchOperator(
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
)
Expand Down
Loading

0 comments on commit ace169a

Please sign in to comment.