feat(awel): New AWEL RAG example
fangyinc committed Nov 21, 2023
1 parent e67d62a commit 1801138
Showing 16 changed files with 548 additions and 179 deletions.
70 changes: 70 additions & 0 deletions examples/awel/simple_rag_example.py
@@ -0,0 +1,70 @@
"""AWEL: Simple rag example
    Example:

    .. code-block:: shell

        curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
        -H "Content-Type: application/json" -d '{
            "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
            "model_name": "proxyllm",
            "chat_mode": "chat_knowledge",
            "user_input": "What is DB-GPT?",
            "select_param": "default"
        }'
"""

from pilot.awel import HttpTrigger, DAG, MapOperator
from pilot.scene.operator._experimental import (
    ChatContext,
    PromptManagerOperator,
    ChatHistoryStorageOperator,
    ChatHistoryOperator,
    EmbeddingEngingOperator,
    BaseChatOperator,
)
from pilot.scene.base import ChatScene
from pilot.openapi.api_view_model import ConversationVo
from pilot.model.base import ModelOutput
from pilot.model.operator.model_operator import ModelOperator


class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
    """Parse the incoming ConversationVo request body into a ChatContext."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def map(self, input_value: ConversationVo) -> ChatContext:
        return ChatContext(
            current_user_input=input_value.user_input,
            model_name=input_value.model_name,
            chat_session_id=input_value.conv_uid,
            select_param=input_value.select_param,
            chat_scene=ChatScene.ChatKnowledge,
        )


with DAG("simple_rag_example") as dag:
    trigger_task = HttpTrigger(
        "/examples/simple_rag", methods="POST", request_body=ConversationVo
    )
    req_parse_task = RequestParseOperator()
    prompt_task = PromptManagerOperator()
    history_storage_task = ChatHistoryStorageOperator()
    history_task = ChatHistoryOperator()
    embedding_task = EmbeddingEngingOperator()
    chat_task = BaseChatOperator()
    model_task = ModelOperator()
    output_parser_task = MapOperator(lambda out: out.to_dict()["text"])

    # Wire the pipeline: HTTP trigger -> request parsing -> prompt, history and
    # embedding preparation -> chat -> model inference -> text extraction.
    (
        trigger_task
        >> req_parse_task
        >> prompt_task
        >> history_storage_task
        >> history_task
        >> embedding_task
        >> chat_task
        >> model_task
        >> output_parser_task
    )
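As a usage note (not part of the commit): a Python client equivalent to the curl example in the module docstring might look like the sketch below, assuming the DB-GPT webserver exposes the AWEL trigger on 127.0.0.1:5000 and the requests package is installed.

import requests

# Same payload as the curl example in the docstring above (sample conv_uid).
payload = {
    "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
    "model_name": "proxyllm",
    "chat_mode": "chat_knowledge",
    "user_input": "What is DB-GPT?",
    "select_param": "default",
}
resp = requests.post(
    "http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag",
    json=payload,
)
# Expected to contain the text produced by the final MapOperator in the DAG.
print(resp.text)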
18 changes: 15 additions & 3 deletions pilot/awel/dag/base.py
@@ -7,6 +7,7 @@
import logging
from collections import deque
from functools import cache
from concurrent.futures import Executor

from pilot.component import SystemApp
from ..resource.base import ResourceGroup
@@ -102,6 +103,7 @@ class DAGVar:
    _thread_local = threading.local()
    _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
    _system_app: SystemApp = None
    _executor: Executor = None

    @classmethod
    def enter_dag(cls, dag) -> None:
@@ -157,6 +159,14 @@ def set_current_system_app(cls, system_app: SystemApp) -> None:
        else:
            cls._system_app = system_app

    @classmethod
    def get_executor(cls) -> Executor:
        return cls._executor

    @classmethod
    def set_executor(cls, executor: Executor) -> None:
        cls._executor = executor


class DAGNode(DependencyMixin, ABC):
    resource_group: Optional[ResourceGroup] = None
@@ -165,9 +175,10 @@ class DAGNode(DependencyMixin, ABC):
    def __init__(
        self,
        dag: Optional["DAG"] = None,
        node_id: str = None,
        node_name: str = None,
        system_app: SystemApp = None,
        node_id: Optional[str] = None,
        node_name: Optional[str] = None,
        system_app: Optional[SystemApp] = None,
        executor: Optional[Executor] = None,
    ) -> None:
        super().__init__()
        self._upstream: List["DAGNode"] = []
@@ -176,6 +187,7 @@ def __init__(
        self._system_app: Optional[SystemApp] = (
            system_app or DAGVar.get_current_system_app()
        )
        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
        if not node_id and self._dag:
            node_id = self._dag._new_node_id()
        self._node_id: str = node_id
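For context, a minimal sketch (an assumption for illustration, not code from this commit) of how the new DAGVar executor hooks could be wired at application startup, so that every DAGNode created afterwards picks up the shared executor via DAGVar.get_executor():

from concurrent.futures import ThreadPoolExecutor

from pilot.awel.dag.base import DAGVar

# Register a process-wide executor once; DAGNode.__init__ now resolves
# executor or DAGVar.get_executor(), so nodes built later reuse this pool.
DAGVar.set_executor(ThreadPoolExecutor(max_workers=4))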
25 changes: 24 additions & 1 deletion pilot/awel/operator/base.py
@@ -14,7 +14,13 @@
)
import functools
from inspect import signature
from pilot.component import SystemApp
from pilot.component import SystemApp, ComponentType
from pilot.utils.executor_utils import (
    ExecutorFactory,
    DefaultExecutorFactory,
    blocking_func_to_async,
    BlockingFunction,
)

from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
from ..task.base import (
@@ -71,6 +77,16 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
        system_app: Optional[SystemApp] = (
            kwargs.get("system_app") or DAGVar.get_current_system_app()
        )
        executor = kwargs.get("executor") or DAGVar.get_executor()
        if not executor:
            if system_app:
                executor = system_app.get_component(
                    ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
                ).create()
            else:
                executor = DefaultExecutorFactory().create()
            DAGVar.set_executor(executor)

        if not task_id and dag:
            task_id = dag._new_node_id()
        runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
@@ -86,6 +102,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
        kwargs["runner"] = runner
        if not kwargs.get("system_app"):
            kwargs["system_app"] = system_app
        if not kwargs.get("executor"):
            kwargs["executor"] = executor
        real_obj = func(self, *args, **kwargs)
        return real_obj

@@ -177,6 +195,11 @@ async def call_stream(
        out_ctx = await self._runner.execute_workflow(self, call_data)
        return out_ctx.current_task_context.task_output.output_stream

    async def blocking_func_to_async(
        self, func: BlockingFunction, *args, **kwargs
    ) -> Any:
        return await blocking_func_to_async(self._executor, func, *args, **kwargs)


def initialize_runner(runner: WorkflowRunner):
    global default_runner
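The executor resolved in apply_defaults (from kwargs, then DAGVar, then the SystemApp's default ExecutorFactory, and finally DefaultExecutorFactory) is what backs the new BaseOperator.blocking_func_to_async helper. Below is a minimal hypothetical operator, not part of this commit, showing how the helper keeps blocking calls off the event loop:

import time

from pilot.awel import MapOperator


def slow_lookup(query: str) -> str:
    # Stand-in for any blocking (CPU- or IO-bound) function.
    time.sleep(1)
    return f"result for {query}"


class SlowLookupOperator(MapOperator[str, str]):
    async def map(self, input_value: str) -> str:
        # Run slow_lookup in the operator's executor instead of blocking the loop.
        return await self.blocking_func_to_async(slow_lookup, input_value)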
4 changes: 2 additions & 2 deletions pilot/awel/runner/local_runner.py
@@ -67,7 +67,7 @@ async def _execute_node(
            node_outputs[node.node_id] = task_ctx
            return
        try:
            logger.info(
            logger.debug(
                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
            )
            await node._run(dag_ctx)
@@ -76,7 +76,7 @@ async def _execute_node(

            if isinstance(node, BranchOperator):
                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
                logger.info(
                logger.debug(
                    f"Current is branch operator, skip node names: {skip_nodes}"
                )
                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
2 changes: 1 addition & 1 deletion pilot/memory/chat_history/store_type/meta_db_history.py
@@ -47,7 +47,7 @@ def create(self, chat_mode, summary: str, user_name: str) -> None:
            logger.error("init create conversation log error!" + str(e))

    def append(self, once_message: OnceConversation) -> None:
        logger.info(f"db history append: {once_message}")
        logger.debug(f"db history append: {once_message}")
        chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
            self.chat_seesion_id
        )
4 changes: 1 addition & 3 deletions pilot/model/proxy/llms/chatgpt.py
@@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params):
    proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
    payloads["model"] = proxyllm_backend

    logger.info(
        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
    )
    logger.info(f"Send request to real model {proxyllm_backend}")
    return history, payloads

