Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: New AWEL tutorial #1245

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,16 @@
UnstreamifyAbsOperator,
)
from .runner.local_runner import DefaultWorkflowRunner
from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
from .task.base import (
InputContext,
InputSource,
TaskContext,
TaskOutput,
TaskState,
is_empty_data,
)
from .task.task_impl import (
BaseInputSource,
DefaultInputContext,
DefaultTaskContext,
SimpleCallDataInputSource,
Expand All @@ -40,6 +48,7 @@
SimpleTaskOutput,
_is_async_iterator,
)
from .trigger.base import Trigger
from .trigger.http_trigger import (
CommonLLMHttpRequestBody,
CommonLLMHTTPRequestContext,
Expand Down Expand Up @@ -73,12 +82,14 @@
"BranchFunc",
"WorkflowRunner",
"TaskState",
"is_empty_data",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"BaseInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
Expand All @@ -87,6 +98,7 @@
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
Expand Down Expand Up @@ -136,9 +148,6 @@ def setup_dev_environment(
Defaults to True. If True, the DAG graph will be saved to a file and open
it automatically.
"""
import uvicorn
from fastapi import FastAPI

from dbgpt.component import SystemApp
from dbgpt.util.utils import setup_logging

Expand All @@ -148,7 +157,13 @@ def setup_dev_environment(
logger_filename = "dbgpt_awel_dev.log"
setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)

app = FastAPI()
start_http = _check_has_http_trigger(dags)
if start_http:
from fastapi import FastAPI

app = FastAPI()
else:
app = None
system_app = SystemApp(app)
DAGVar.set_current_system_app(system_app)
trigger_manager = DefaultTriggerManager()
Expand All @@ -169,6 +184,24 @@ def setup_dev_environment(
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger, system_app)
trigger_manager.after_register()
if trigger_manager.keep_running():
if start_http and trigger_manager.keep_running() and app:
import uvicorn

# Should keep running
uvicorn.run(app, host=host, port=port)


def _check_has_http_trigger(dags: List[DAG]) -> bool:
    """Check whether any of the given DAGs contains an HTTP trigger node.

    Args:
        dags (List[DAG]): The DAGs to inspect.

    Returns:
        bool: True if at least one trigger node of any DAG is an
            ``HttpTrigger``, otherwise False.
    """
    return any(
        isinstance(trigger, HttpTrigger)
        for dag in dags
        for trigger in dag.trigger_nodes
    )
10 changes: 0 additions & 10 deletions dbgpt/core/awel/base.py

This file was deleted.

37 changes: 35 additions & 2 deletions dbgpt/core/awel/dag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def __init__(
node_id = self._dag._new_node_id()
self._node_id: Optional[str] = node_id
self._node_name: Optional[str] = node_name
if self._dag:
self._dag._append_node(self)

@property
def node_id(self) -> str:
Expand Down Expand Up @@ -421,7 +423,7 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non
def __repr__(self):
"""Return the representation of current DAGNode."""
cls_name = self.__class__.__name__
if self.node_name and self.node_name:
if self.node_id and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
if self.node_id:
return f"{cls_name}(node_id={self.node_id})"
Expand All @@ -430,6 +432,19 @@ def __repr__(self):
else:
return f"{cls_name}"

@property
def graph_str(self):
    """Return a Mermaid-style node label for this DAGNode.

    The label prefers the node id as the graph vertex identifier; the
    class name (and node name, when present) goes in parentheses.
    """
    cls_name = type(self).__name__
    node_id = self.node_id
    node_name = self.node_name
    if node_id:
        # With an id, append the name inside the parentheses when we have one.
        suffix = f",{node_name}" if node_name else ""
        return f"{node_id}({cls_name}{suffix})"
    if node_name:
        return f"{node_name}_{cls_name}({cls_name})"
    return f"{cls_name}"

def __str__(self):
    """Return the string form of this DAGNode (same as its repr)."""
    return repr(self)
Expand Down Expand Up @@ -798,12 +813,16 @@ def _handle_dag_nodes(
_handle_dag_nodes(is_down_to_up, level, node, func)


def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
def _visualize_dag(
dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
) -> Optional[str]:
"""Visualize the DAG.

Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.
generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
Defaults to True.

Returns:
Optional[str]: The filename of the DAG graph
Expand All @@ -815,15 +834,20 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
return None

dot = Digraph(name=dag.dag_id)
mermaid_str = "graph TD;\n" # Initialize Mermaid graph definition
# Record the added edges to avoid adding duplicate edges
added_edges = set()

def add_edges(node: DAGNode):
nonlocal mermaid_str
if node.downstream:
for downstream_node in node.downstream:
# Check if the edge has been added
if (str(node), str(downstream_node)) not in added_edges:
dot.edge(str(node), str(downstream_node))
mermaid_str += (
f" {node.graph_str} --> {downstream_node.graph_str};\n"
)
added_edges.add((str(node), str(downstream_node)))
add_edges(downstream_node)

Expand All @@ -839,4 +863,13 @@ def add_edges(node: DAGNode):

kwargs["directory"] = LOGDIR

# Generate Mermaid syntax file if requested
if generate_mermaid:
mermaid_filename = filename.replace(".gv", ".md")
with open(
f"{kwargs.get('directory', '')}/{mermaid_filename}", "w"
) as mermaid_file:
logger.info(f"Writing Mermaid syntax to {mermaid_filename}")
mermaid_file.write(mermaid_str)

return dot.render(filename, view=view, **kwargs)
14 changes: 7 additions & 7 deletions dbgpt/core/awel/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)

from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
from ..task.base import OUT, T, TaskOutput
from ..task.base import EMPTY_DATA, OUT, T, TaskOutput

F = TypeVar("F", bound=FunctionType)

Expand Down Expand Up @@ -186,7 +186,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:

async def call(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
) -> OUT:
"""Execute the node and return the output.
Expand All @@ -200,7 +200,7 @@ async def call(
Returns:
OUT: The output of the node after execution.
"""
if call_data:
if call_data != EMPTY_DATA:
call_data = {"data": call_data}
out_ctx = await self._runner.execute_workflow(
self, call_data, exist_dag_ctx=dag_ctx
Expand All @@ -209,7 +209,7 @@ async def call(

def _blocking_call(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> OUT:
"""Execute the node and return the output.
Expand All @@ -232,7 +232,7 @@ def _blocking_call(

async def call_stream(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
dag_ctx: Optional[DAGContext] = None,
) -> AsyncIterator[OUT]:
"""Execute the node and return the output as a stream.
Expand All @@ -247,7 +247,7 @@ async def call_stream(
Returns:
AsyncIterator[OUT]: An asynchronous iterator over the output stream.
"""
if call_data:
if call_data != EMPTY_DATA:
call_data = {"data": call_data}
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
Expand All @@ -256,7 +256,7 @@ async def call_stream(

def _blocking_call_stream(
self,
call_data: Optional[CALL_DATA] = None,
call_data: Optional[CALL_DATA] = EMPTY_DATA,
loop: Optional[asyncio.BaseEventLoop] = None,
) -> Iterator[OUT]:
"""Execute the node and return the output as a stream.
Expand Down
13 changes: 2 additions & 11 deletions dbgpt/core/awel/operators/common_operator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import (
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Union,
)
from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union

from ..dag.base import DAGContext
from ..task.base import (
Expand Down Expand Up @@ -106,7 +97,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
curr_task_ctx.set_task_output(reduce_output)
return reduce_output

async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
async def reduce(self, a: IN, b: IN) -> OUT:
    """Combine two input values into a single reduced value.

    Applied repeatedly to fold the input stream down to one result.

    Args:
        a (IN): The first (or accumulated) value.
        b (IN): The next value to combine with *a*.

    Returns:
        OUT: The combined value.

    Raises:
        NotImplementedError: Always; subclasses must override this method.
    """
    raise NotImplementedError

Expand Down
18 changes: 13 additions & 5 deletions dbgpt/core/awel/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@


class _EMPTY_DATA_TYPE:
"""A special type to represent empty data."""

def __init__(self, name: str = "EMPTY_DATA"):
self.name = name

def __bool__(self):
return False

def __str__(self):
return f"EmptyData({self.name})"


EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
# Module-level sentinel instances. All are falsy; since _EMPTY_DATA_TYPE has
# no custom __eq__, compare with `is` (or use is_empty_data) rather than `==`.
EMPTY_DATA = _EMPTY_DATA_TYPE("EMPTY_DATA")  # "no call data supplied" default
SKIP_DATA = _EMPTY_DATA_TYPE("SKIP_DATA")  # presumably marks skipped tasks — confirm usage
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE("PLACEHOLDER_DATA")  # placeholder value; semantics defined by callers


def is_empty_data(data: Any):
Expand All @@ -37,7 +45,7 @@ def is_empty_data(data: Any):


MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN, IN], OUT], Callable[[IN, IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
Expand Down Expand Up @@ -341,7 +349,7 @@ async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
"""

@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
async def reduce(self, reduce_func: ReduceFunc) -> "InputContext":
"""Apply a reducing function to the inputs.

Args:
Expand Down
9 changes: 6 additions & 3 deletions dbgpt/core/awel/task/task_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,8 @@ async def _apply_func(
if apply_type == "map":
result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
elif apply_type == "reduce":
result = out.task_output.reduce(func)
reduce_func = cast(ReduceFunc, func)
result = out.task_output.reduce(reduce_func)
elif apply_type == "check_condition":
result = out.task_output.check_condition(func)
else:
Expand Down Expand Up @@ -541,14 +542,16 @@ async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
)
return DefaultInputContext([single_output])

async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
async def reduce(self, reduce_func: ReduceFunc) -> InputContext:
"""Apply a reduce function to all parent outputs."""
if not self.check_stream():
raise ValueError(
"The output in all tasks must has same output format of stream to apply"
" reduce function"
)
new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
new_outputs, results = await self._apply_func(
reduce_func, apply_type="reduce" # type: ignore
)
for i, task_ctx in enumerate(new_outputs):
task_ctx = cast(TaskContext, task_ctx)
task_ctx.set_task_output(results[i])
Expand Down
Loading
Loading