Skip to content

Commit

Permalink
feat(awel): AWEL supports http trigger
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Nov 21, 2023
1 parent 35c1221 commit e67d62a
Show file tree
Hide file tree
Showing 20 changed files with 655 additions and 12 deletions.
54 changes: 54 additions & 0 deletions examples/awel/simple_chat_dag_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""AWEL: Simple chat dag example
Example:
.. code-block:: shell
curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"user_input": "hello"
}'
"""
from typing import Dict
from pydantic import BaseModel, Field

from pilot.awel import DAG, HttpTrigger, MapOperator
from pilot.scene.base_message import ModelMessage
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator


class TriggerReqBody(BaseModel):
model: str = Field(..., description="Model name")
user_input: str = Field(..., description="User input")


class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
def __init__(self, **kwargs):
super().__init__(**kwargs)

async def map(self, input_value: TriggerReqBody) -> Dict:
hist = []
hist.append(ModelMessage.build_human_message(input_value.user_input))
hist = list(h.dict() for h in hist)
params = {
"prompt": input_value.user_input,
"messages": hist,
"model": input_value.model,
"echo": False,
}
print(f"Receive input value: {input_value}")
return params


with DAG("dbgpt_awel_simple_dag_example") as dag:
# Receive http request and trigger dag to run.
trigger = HttpTrigger(
"/examples/simple_chat", methods="POST", request_body=TriggerReqBody
)
request_handle_task = RequestHandleOperator()
model_task = ModelOperator()
# type(out) == ModelOutput
model_parse_task = MapOperator(lambda out: out.to_dict())
trigger >> request_handle_task >> model_task >> model_parse_task
32 changes: 32 additions & 0 deletions examples/awel/simple_dag_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""AWEL: Simple dag example
Example:
.. code-block:: shell
curl -X GET http://127.0.0.1:5000/api/v1/awel/trigger/examples/hello\?name\=zhangsan
"""
from pydantic import BaseModel, Field

from pilot.awel import DAG, HttpTrigger, MapOperator


class TriggerReqBody(BaseModel):
name: str = Field(..., description="User name")
age: int = Field(18, description="User age")


class RequestHandleOperator(MapOperator[TriggerReqBody, str]):
def __init__(self, **kwargs):
super().__init__(**kwargs)

async def map(self, input_value: TriggerReqBody) -> str:
print(f"Receive input value: {input_value}")
return f"Hello, {input_value.name}, your age is {input_value.age}"


with DAG("simple_dag_example") as dag:
trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
map_node = RequestHandleOperator()
trigger >> map_node
31 changes: 29 additions & 2 deletions pilot/awel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
"""Agentic Workflow Expression Language (AWEL)"""
"""Agentic Workflow Expression Language (AWEL)
Note:
AWEL is still an experimental feature and only opens the lowest level API.
The stability of this API cannot be guaranteed at present.
"""

from pilot.component import SystemApp

from .dag.base import DAGContext, DAG

from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
from .operator.base import BaseOperator, WorkflowRunner
from .operator.common_operator import (
JoinOperator,
ReduceStreamOperator,
Expand All @@ -28,6 +37,7 @@
SimpleStreamTaskOutput,
_is_async_iterator,
)
from .trigger.http_trigger import HttpTrigger
from .runner.local_runner import DefaultWorkflowRunner

__all__ = [
Expand Down Expand Up @@ -57,4 +67,21 @@
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
]


def initialize_awel(system_app: SystemApp, dag_filepath: str):
from .dag.dag_manager import DAGManager
from .dag.base import DAGVar
from .trigger.trigger_manager import DefaultTriggerManager
from .operator.base import initialize_runner

DAGVar.set_current_system_app(system_app)

system_app.register(DefaultTriggerManager)
dag_manager = DAGManager(system_app, dag_filepath)
system_app.register_instance(dag_manager)
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()
7 changes: 7 additions & 0 deletions pilot/awel/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod


class Trigger(ABC):
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""
88 changes: 85 additions & 3 deletions pilot/awel/dag/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Sequence, Union, Any
from typing import Optional, Dict, List, Sequence, Union, Any, Set
import uuid
import contextvars
import threading
import asyncio
import logging
from collections import deque
from functools import cache

from pilot.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext

logger = logging.getLogger(__name__)

DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]


Expand Down Expand Up @@ -96,6 +101,7 @@ def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin":
class DAGVar:
_thread_local = threading.local()
_async_local = contextvars.ContextVar("current_dag_stack", default=deque())
_system_app: SystemApp = None

@classmethod
def enter_dag(cls, dag) -> None:
Expand Down Expand Up @@ -138,18 +144,38 @@ def get_current_dag(cls) -> Optional["DAG"]:
return cls._thread_local.current_dag_stack[-1]
return None

@classmethod
def get_current_system_app(cls) -> SystemApp:
if not cls._system_app:
raise RuntimeError("System APP not set for DAGVar")
return cls._system_app

@classmethod
def set_current_system_app(cls, system_app: SystemApp) -> None:
if cls._system_app:
logger.warn("System APP has already set, nothing to do")
else:
cls._system_app = system_app


class DAGNode(DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""

def __init__(
self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None
self,
dag: Optional["DAG"] = None,
node_id: str = None,
node_name: str = None,
system_app: SystemApp = None,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
self._downstream: List["DAGNode"] = []
self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag()
self._system_app: Optional[SystemApp] = (
system_app or DAGVar.get_current_system_app()
)
if not node_id and self._dag:
node_id = self._dag._new_node_id()
self._node_id: str = node_id
Expand All @@ -159,6 +185,10 @@ def __init__(
def node_id(self) -> str:
return self._node_id

@property
def system_app(self) -> SystemApp:
return self._system_app

def set_node_id(self, node_id: str) -> None:
self._node_id = node_id

Expand All @@ -178,7 +208,7 @@ def node_name(self) -> str:
return self._node_name

@property
def dag(self) -> "DAGNode":
def dag(self) -> "DAG":
return self._dag

def set_upstream(self, nodes: DependencyType) -> "DAGNode":
Expand Down Expand Up @@ -254,17 +284,69 @@ class DAG:
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Set[DAGNode] = None
self._leaf_nodes: Set[DAGNode] = None
self._trigger_nodes: Set[DAGNode] = None

def _append_node(self, node: DAGNode) -> None:
self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
self._leaf_nodes = None

def _new_node_id(self) -> str:
return str(uuid.uuid4())

@property
def dag_id(self) -> str:
return self._dag_id

def _build(self) -> None:
from ..operator.common_operator import TriggerOperator

nodes = set()
for _, node in self.node_map.items():
nodes = nodes.union(_get_nodes(node))
self._root_nodes = list(set(filter(lambda x: not x.upstream, nodes)))
self._leaf_nodes = list(set(filter(lambda x: not x.downstream, nodes)))
self._trigger_nodes = list(
set(filter(lambda x: isinstance(x, TriggerOperator), nodes))
)

@property
def root_nodes(self) -> List[DAGNode]:
if not self._root_nodes:
self._build()
return self._root_nodes

@property
def leaf_nodes(self) -> List[DAGNode]:
if not self._leaf_nodes:
self._build()
return self._leaf_nodes

@property
def trigger_nodes(self):
if not self._trigger_nodes:
self._build()
return self._trigger_nodes

def __enter__(self):
DAGVar.enter_dag(self)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
DAGVar.exit_dag()


def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode]:
nodes = set()
if not node:
return nodes
nodes.add(node)
stream_nodes = node.upstream if is_upstream else node.downstream
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes
42 changes: 42 additions & 0 deletions pilot/awel/dag/dag_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Dict, Optional
import logging
from pilot.component import BaseComponent, ComponentType, SystemApp
from .loader import DAGLoader, LocalFileDAGLoader
from .base import DAG

logger = logging.getLogger(__name__)


class DAGManager(BaseComponent):
name = ComponentType.AWEL_DAG_MANAGER

def __init__(self, system_app: SystemApp, dag_filepath: str):
super().__init__(system_app)
self.dag_loader = LocalFileDAGLoader(dag_filepath)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}

def init_app(self, system_app: SystemApp):
self.system_app = system_app

def load_dags(self):
dags = self.dag_loader.load_dags()
triggers = []
for dag in dags:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Load DAG error, DAG ID {dag_id} has already exist")
triggers += dag.trigger_nodes
from ..trigger.trigger_manager import DefaultTriggerManager

trigger_manager: DefaultTriggerManager = self.system_app.get_component(
ComponentType.AWEL_TRIGGER_MANAGER,
DefaultTriggerManager,
default_component=None,
)
if trigger_manager:
for trigger in triggers:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
else:
logger.warn("No trigger manager, not register dag trigger")
Loading

0 comments on commit e67d62a

Please sign in to comment.