Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(awel): Modify AWEL http trigger route function #817

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/awel/simple_chat_dag_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AWEL: Simple chat dag example

DB-GPT will automatically load and execute the current file after startup.

Example:

.. code-block:: shell
Expand Down
2 changes: 2 additions & 0 deletions examples/awel/simple_dag_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AWEL: Simple dag example

DB-GPT will automatically load and execute the current file after startup.

Example:

.. code-block:: shell
Expand Down
3 changes: 3 additions & 0 deletions examples/awel/simple_rag_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""AWEL: Simple rag example

DB-GPT will automatically load and execute the current file after startup.

Example:

.. code-block:: shell
Expand Down Expand Up @@ -49,6 +51,7 @@ async def map(self, input_value: ConversationVo) -> ChatContext:
"/examples/simple_rag", methods="POST", request_body=ConversationVo
)
req_parse_task = RequestParseOperator()
# TODO should register prompt template first
prompt_task = PromptManagerOperator()
history_storage_task = ChatHistoryStorageOperator()
history_task = ChatHistoryOperator()
Expand Down
6 changes: 1 addition & 5 deletions pilot/awel/trigger/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from abc import ABC, abstractmethod

from ..operator.base import BaseOperator
from ..operator.common_operator import TriggerOperator
from ..dag.base import DAGContext
from ..task.base import TaskOutput


class Trigger(TriggerOperator, ABC):
@abstractmethod
async def trigger(self, end_operator: "BaseOperator") -> None:
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""
80 changes: 50 additions & 30 deletions pilot/awel/trigger/http_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging

from .base import Trigger
from ..dag.base import DAG
from ..operator.base import BaseOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,46 +51,33 @@ async def trigger(self) -> None:

def mount_to_router(self, router: "APIRouter") -> None:
from fastapi import Depends
from fastapi.responses import StreamingResponse

methods = self._methods if isinstance(self._methods, list) else [self._methods]

def create_route_function(name):
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)

async def route_function(body: Any = Depends(_request_body_dependency)):
end_node = self.dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not self._streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = self._response_headers
media_type = (
self._response_media_type
if self._response_media_type
else "text/event-stream"
)
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
async def route_function(body=Depends(_request_body_dependency)):
return await _trigger_dag(
body,
self.dag,
self._streaming_response,
self._response_headers,
self._response_media_type,
)

route_function.__name__ = name
return route_function

function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
dynamic_route_function = create_route_function(function_name)
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
request_model = (
self._req_body
if isinstance(self._req_body, type)
and issubclass(self._req_body, BaseModel)
else None
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
)
Expand All @@ -115,3 +103,35 @@ async def _parse_request_body(
return request_body_cls(**request.query_params)
else:
return request


async def _trigger_dag(
body: Any,
dag: DAG,
streaming_response: Optional[bool] = False,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
) -> Any:
from fastapi.responses import StreamingResponse

end_node = dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = response_headers
media_type = response_media_type if response_media_type else "text/event-stream"
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)
7 changes: 1 addition & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,6 @@ def cache_requires():
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]


# def chat_scene():
# setup_spec.extras["chat"] = [
# ""
# ]


def default_requires():
"""
pip install "db-gpt[default]"
Expand All @@ -445,6 +439,7 @@ def default_requires():
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
setup_spec.extras["default"] += setup_spec.extras["torch"]
setup_spec.extras["default"] += setup_spec.extras["quantization"]
setup_spec.extras["default"] += setup_spec.extras["cache"]


def all_requires():
Expand Down