docs: New AWEL tutorial (#1245)
fangyinc authored Mar 4, 2024
1 parent 7a38edc commit 3c93fe5
Showing 42 changed files with 17,658 additions and 11,112 deletions.
45 changes: 39 additions & 6 deletions dbgpt/core/awel/__init__.py
@@ -30,8 +30,16 @@
UnstreamifyAbsOperator,
)
from .runner.local_runner import DefaultWorkflowRunner
from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
from .task.base import (
InputContext,
InputSource,
TaskContext,
TaskOutput,
TaskState,
is_empty_data,
)
from .task.task_impl import (
BaseInputSource,
DefaultInputContext,
DefaultTaskContext,
SimpleCallDataInputSource,
@@ -40,6 +48,7 @@
SimpleTaskOutput,
_is_async_iterator,
)
from .trigger.base import Trigger
from .trigger.http_trigger import (
CommonLLMHttpRequestBody,
CommonLLMHTTPRequestContext,
@@ -73,12 +82,14 @@
"BranchFunc",
"WorkflowRunner",
"TaskState",
"is_empty_data",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"BaseInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
@@ -87,6 +98,7 @@
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
@@ -136,9 +148,6 @@ def setup_dev_environment(
Defaults to True. If True, the DAG graph will be saved to a file and open
it automatically.
"""
import uvicorn
from fastapi import FastAPI

from dbgpt.component import SystemApp
from dbgpt.util.utils import setup_logging

@@ -148,7 +157,13 @@
logger_filename = "dbgpt_awel_dev.log"
setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)

app = FastAPI()
start_http = _check_has_http_trigger(dags)
if start_http:
from fastapi import FastAPI

app = FastAPI()
else:
app = None
system_app = SystemApp(app)
DAGVar.set_current_system_app(system_app)
trigger_manager = DefaultTriggerManager()
@@ -169,6 +184,24 @@
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger, system_app)
trigger_manager.after_register()
if trigger_manager.keep_running():
if start_http and trigger_manager.keep_running() and app:
import uvicorn

# Should keep running
uvicorn.run(app, host=host, port=port)


def _check_has_http_trigger(dags: List[DAG]) -> bool:
"""Check whether has http trigger.
Args:
dags (List[DAG]): The dags.
Returns:
bool: Whether has http trigger.
"""
for dag in dags:
for trigger in dag.trigger_nodes:
if isinstance(trigger, HttpTrigger):
return True
return False
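
The net effect of this hunk is that setup_dev_environment only imports FastAPI and starts uvicorn when at least one DAG actually contains an HttpTrigger. The snippet below is a minimal sketch, not part of this commit, assuming the MapOperator subclass pattern and the root-operator call(call_data=...) usage from the AWEL getting-started examples; the DAG name and operator are made up.

import asyncio

from dbgpt.core.awel import DAG, MapOperator


class DoubleOperator(MapOperator[int, int]):
    """Toy operator that doubles its input (illustrative only)."""

    async def map(self, x: int) -> int:
        return x * 2


with DAG("local_only_dag") as dag:
    task = DoubleOperator()

# No HttpTrigger in the DAG, so _check_has_http_trigger([dag]) is False and
# setup_dev_environment([dag], ...) would register the DAG without building a
# FastAPI app or starting uvicorn. The operator can still be driven directly:
print(asyncio.run(task.call(call_data=21)))  # expected: 42
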
10 changes: 0 additions & 10 deletions dbgpt/core/awel/base.py

This file was deleted.

37 changes: 35 additions & 2 deletions dbgpt/core/awel/dag/base.py
@@ -274,6 +274,8 @@ def __init__(
node_id = self._dag._new_node_id()
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
if self._dag:
self._dag._append_node(self)

@property
def node_id(self) -> str:
@@ -421,7 +423,7 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non
def __repr__(self):
"""Return the representation of current DAGNode."""
cls_name = self.__class__.__name__
if self.node_name and self.node_name:
if self.node_id and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
if self.node_id:
return f"{cls_name}(node_id={self.node_id})"
@@ -430,6 +432,19 @@ def __repr__(self):
else:
return f"{cls_name}"

@property
def graph_str(self):
"""Return the graph string of current DAGNode."""
cls_name = self.__class__.__name__
if self.node_id and self.node_name:
return f"{self.node_id}({cls_name},{self.node_name})"
if self.node_id:
return f"{self.node_id}({cls_name})"
if self.node_name:
return f"{self.node_name}_{cls_name}({cls_name})"
else:
return f"{cls_name}"

def __str__(self):
"""Return the string of current DAGNode."""
return self.__repr__()
@@ -798,12 +813,16 @@ def _handle_dag_nodes(
_handle_dag_nodes(is_down_to_up, level, node, func)


def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
def _visualize_dag(
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
"""Visualize the DAG.
Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
Defaults to True.
Returns:
Optional[str]: The filename of the DAG graph
@@ -815,15 +834,20 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
return None

dot = Digraph(name=dag.dag_id)
mermaid_str = "graph TD;\n" # Initialize Mermaid graph definition
# Record the added edges to avoid adding duplicate edges
added_edges = set()

def add_edges(node: DAGNode):
nonlocal mermaid_str
if node.downstream:
for downstream_node in node.downstream:
# Check if the edge has been added
if (str(node), str(downstream_node)) not in added_edges:
dot.edge(str(node), str(downstream_node))
mermaid_str += (
f" {node.graph_str} --> {downstream_node.graph_str};\n"
)
added_edges.add((str(node), str(downstream_node)))
add_edges(downstream_node)

@@ -839,4 +863,13 @@ def add_edges(node: DAGNode):

kwargs["directory"] = LOGDIR

# Generate Mermaid syntax file if requested
if generate_mermaid:
mermaid_filename = filename.replace(".gv", ".md")
with open(
f"{kwargs.get('directory', '')}/{mermaid_filename}", "w"
) as mermaid_file:
logger.info(f"Writing Mermaid syntax to {mermaid_filename}")
mermaid_file.write(mermaid_str)

return dot.render(filename, view=view, **kwargs)
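
Besides the Graphviz .gv output, the visualizer now also writes a Mermaid file (the .gv filename with a .md suffix) assembled from DAGNode.graph_str. The standalone snippet below only illustrates the string format that _visualize_dag builds up; the node ids and class names are made up and nothing here imports dbgpt.

from typing import List, Optional, Tuple


def graph_str(node_id: str, cls_name: str, node_name: Optional[str]) -> str:
    # Mirrors DAGNode.graph_str for nodes with an id: "node_id(ClassName,node_name)"
    # when a node name is set, otherwise "node_id(ClassName)".
    if node_name:
        return f"{node_id}({cls_name},{node_name})"
    return f"{node_id}({cls_name})"


def build_mermaid(edges: List[Tuple[str, str]]) -> str:
    # Mirrors the accumulation in _visualize_dag: a "graph TD;" header plus one
    # "A --> B;" line per unique edge.
    lines = ["graph TD;"]
    for src, dst in edges:
        lines.append(f"    {src} --> {dst};")
    return "\n".join(lines) + "\n"


edges = [
    (
        graph_str("trigger_1", "HttpTrigger", "request_trigger"),
        graph_str("map_1", "MapOperator", None),
    )
]
print(build_mermaid(edges))
# graph TD;
#     trigger_1(HttpTrigger,request_trigger) --> map_1(MapOperator);
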
14 changes: 7 additions & 7 deletions dbgpt/core/awel/operators/base.py
@@ -24,7 +24,7 @@
)

from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
from ..task.base import OUT, T, TaskOutput
from ..task.base import EMPTY_DATA, OUT, T, TaskOutput

F = TypeVar("F", bound=FunctionType)

@@ -186,7 +186,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:

async def call(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
) -> OUT:
"""Execute the node and return the output.
@@ -200,7 +200,7 @@ def call(
Returns:
OUT: The output of the node after execution.
"""
if call_data:
if call_data != EMPTY_DATA:
call_data = {"data": call_data}
out_ctx = await self._runner.execute_workflow(
self, call_data, exist_dag_ctx=dag_ctx
@@ -209,7 +209,7 @@

def _blocking_call(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> OUT:
"""Execute the node and return the output.
Expand All @@ -232,7 +232,7 @@ def _blocking_call(

async def call_stream(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
@@ -247,7 +247,7 @@
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
if call_data:
if call_data != EMPTY_DATA:
call_data = {"data": call_data}
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
@@ -256,7 +256,7 @@

def _blocking_call_stream(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> Iterator[OUT]:
"""Execute the node and return the output as a stream.
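
The switch from a None default to the EMPTY_DATA sentinel in call, call_stream and their blocking variants fixes falsy inputs: with `if call_data:`, a legitimate value such as 0 or "" was never wrapped into {"data": ...} and so never reached the workflow as call data. A rough sketch of the fixed behavior, assuming MapOperator accepts a map_function argument as in the AWEL tutorial examples:

import asyncio

from dbgpt.core.awel import DAG, MapOperator

with DAG("falsy_call_data_example") as dag:
    task = MapOperator(map_function=lambda x: x + 1)

# 0 is falsy but real data; with EMPTY_DATA as the default sentinel it is now
# wrapped as {"data": 0} and forwarded to the map function.
print(asyncio.run(task.call(call_data=0)))  # expected: 1
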
13 changes: 2 additions & 11 deletions dbgpt/core/awel/operators/common_operator.py
@@ -1,16 +1,7 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import (
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Union,
)
from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union

from ..dag.base import DAGContext
from ..task.base import (
@@ -106,7 +97,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx.set_task_output(reduce_output)
return reduce_output

async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
async def reduce(self, a: IN, b: IN) -> OUT:
"""Reduce the input stream to a single value."""
raise NotImplementedError

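
The reduce contract changes from consuming the whole AsyncIterator to a pairwise fold, matching the updated ReduceFunc alias (Callable[[IN, IN], OUT]) in task/base.py below. A short sketch of a subclass under the new signature, assuming ReduceStreamOperator is the class this hunk belongs to and that it is importable from dbgpt.core.awel:

from dbgpt.core.awel import ReduceStreamOperator


class SumOperator(ReduceStreamOperator[int, int]):
    """Fold an upstream stream of ints into their sum (illustrative only)."""

    async def reduce(self, a: int, b: int) -> int:
        # Pairwise contract: called repeatedly over the stream, in the spirit
        # of functools.reduce, rather than receiving the whole AsyncIterator.
        return a + b
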
18 changes: 13 additions & 5 deletions dbgpt/core/awel/task/base.py
@@ -20,13 +20,21 @@


class _EMPTY_DATA_TYPE:
"""A special type to represent empty data."""

def __init__(self, name: str = "EMPTY_DATA"):
self.name = name

def __bool__(self):
return False

def __str__(self):
return f"EmptyData({self.name})"


EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
EMPTY_DATA = _EMPTY_DATA_TYPE("EMPTY_DATA")
SKIP_DATA = _EMPTY_DATA_TYPE("SKIP_DATA")
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE("PLACEHOLDER_DATA")


def is_empty_data(data: Any):
@@ -37,7 +45,7 @@ def is_empty_data(data: Any):


MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN, IN], OUT], Callable[[IN, IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
@@ -341,7 +349,7 @@ async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
"""

@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
async def reduce(self, reduce_func: ReduceFunc) -> "InputContext":
"""Apply a reducing function to the inputs.
Args:
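
Each sentinel now carries a name and still evaluates to False, so EMPTY_DATA, SKIP_DATA and PLACEHOLDER_DATA are distinguishable in logs and reprs. A small sketch, assuming is_empty_data (newly exported from dbgpt.core.awel in this commit) reports whether a value is one of these sentinels:

from dbgpt.core.awel import is_empty_data
from dbgpt.core.awel.task.base import EMPTY_DATA, SKIP_DATA

# Still falsy, but now self-describing when printed.
assert not EMPTY_DATA and not SKIP_DATA
print(EMPTY_DATA)  # EmptyData(EMPTY_DATA)
print(SKIP_DATA)   # EmptyData(SKIP_DATA)

print(is_empty_data(EMPTY_DATA))  # expected: True (a sentinel, not task data)
print(is_empty_data(0))           # expected: False (0 is legitimate task data)
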
9 changes: 6 additions & 3 deletions dbgpt/core/awel/task/task_impl.py
@@ -479,7 +479,8 @@ async def _apply_func(
if apply_type == "map":
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
reduce_func = cast(ReduceFunc, func)
result = out.task_output.reduce(reduce_func)
elif apply_type == "check_condition":
result = out.task_output.check_condition(func)
else:
@@ -541,14 +542,16 @@ async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
)
return DefaultInputContext([single_output])

async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
async def reduce(self, reduce_func: ReduceFunc) -> InputContext:
"""Apply a reduce function to all parent outputs."""
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply"
" reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
new_outputs, results = await self._apply_func(
reduce_func, apply_type="reduce" # type: ignore
)
for i, task_ctx in enumerate(new_outputs):
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
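
The cast simply lets _apply_func hand the two-argument ReduceFunc through to TaskOutput.reduce. Conceptually, the reduce path folds a collected stream pairwise; the standalone snippet below illustrates that contract without touching dbgpt (all names are made up):

import asyncio
from functools import reduce
from typing import AsyncIterator, Callable


async def fold_stream(
    stream: AsyncIterator[int], fn: Callable[[int, int], int]
) -> int:
    # Collect the stream, then fold it pairwise, which is what a two-argument
    # ReduceFunc implies for a streaming task output.
    items = [item async for item in stream]
    return reduce(fn, items)


async def demo() -> None:
    async def numbers() -> AsyncIterator[int]:
        for i in range(1, 5):
            yield i

    print(await fold_stream(numbers(), lambda a, b: a + b))  # expected: 10


asyncio.run(demo())
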