diff --git a/.env.template b/.env.template
index e03650033..ba7f752db 100644
--- a/.env.template
+++ b/.env.template
@@ -55,6 +55,17 @@ QUANTIZE_8bit=True
 ## Model path
 # llama_cpp_model_path=/data/models/TheBloke/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q4_K_M.gguf
 
+### LLM cache
+## Enable model cache
+# MODEL_CACHE_ENABLE=True
+## The storage type of the model cache; currently supports: memory, disk
+# MODEL_CACHE_STORAGE_TYPE=disk
+## The maximum amount of cache data (MB) to keep in memory; cache data is always stored in memory first for speed
+# MODEL_CACHE_MAX_MEMORY_MB=256
+## The directory for cache data; only takes effect when MODEL_CACHE_STORAGE_TYPE=disk
+## The default directory is pilot/data/model_cache
+# MODEL_CACHE_STORAGE_DISK_DIR=
+
 #*******************************************************************#
 #**                        EMBEDDING SETTINGS                     **#
 #*******************************************************************#
diff --git a/pilot/awel/__init__.py b/pilot/awel/__init__.py
new file mode 100644
index 000000000..6c5313b5d
--- /dev/null
+++ b/pilot/awel/__init__.py
@@ -0,0 +1,60 @@
+"""Agentic Workflow Expression Language (AWEL)"""
+
+from .dag.base import DAGContext, DAG
+
+from .operator.base import BaseOperator, WorkflowRunner, initialize_awel
+from .operator.common_operator import (
+    JoinOperator,
+    ReduceStreamOperator,
+    MapOperator,
+    BranchOperator,
+    InputOperator,
+    BranchFunc,
+)
+
+from .operator.stream_operator import (
+    StreamifyAbsOperator,
+    UnstreamifyAbsOperator,
+    TransformStreamAbsOperator,
+)
+
+from .task.base import TaskState, TaskOutput, TaskContext, InputContext, InputSource
+from .task.task_impl import (
+    SimpleInputSource,
+    SimpleCallDataInputSource,
+    DefaultTaskContext,
+    DefaultInputContext,
+    SimpleTaskOutput,
+    SimpleStreamTaskOutput,
+    _is_async_iterator,
+)
+from .runner.local_runner import DefaultWorkflowRunner
+
+__all__ = [
+    "initialize_awel",
+    "DAGContext",
+    "DAG",
+    "BaseOperator",
+    "JoinOperator",
+    "ReduceStreamOperator",
+    "MapOperator",
+    "BranchOperator",
+    "InputOperator",
+    "BranchFunc",
+    "WorkflowRunner",
+    "TaskState",
+    "TaskOutput",
+    "TaskContext",
+    "InputContext",
+    "InputSource",
+    "DefaultWorkflowRunner",
+    "SimpleInputSource",
+    "SimpleCallDataInputSource",
+    "DefaultTaskContext",
+    "DefaultInputContext",
+    "SimpleTaskOutput",
+    "SimpleStreamTaskOutput",
+    "StreamifyAbsOperator",
+    "UnstreamifyAbsOperator",
+    "TransformStreamAbsOperator",
+]
diff --git a/pilot/awel/dag/__init__.py b/pilot/awel/dag/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/dag/base.py b/pilot/awel/dag/base.py
new file mode 100644
index 000000000..a6ad08990
--- /dev/null
+++ b/pilot/awel/dag/base.py
@@ -0,0 +1,270 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Dict, List, Sequence, Union, Any
+import uuid
+import contextvars
+import threading
+import asyncio
+from collections import deque
+
+from ..resource.base import ResourceGroup
+from ..task.base import TaskContext
+
+DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]]
+
+
+def _is_async_context():
+    try:
+        loop = asyncio.get_running_loop()
+        return asyncio.current_task(loop=loop) is not None
+    except RuntimeError:
+        return False
+
+
+class DependencyMixin(ABC):
+    @abstractmethod
+    def set_upstream(self, nodes: DependencyType) -> "DependencyMixin":
+        """Set one or more upstream nodes for this node.
+
+        Args:
+            nodes (DependencyType): Upstream nodes to be set to current node.
+
+        Returns:
+            DependencyMixin: Returns self to allow method chaining.
+ + Raises: + ValueError: If no upstream nodes are provided or if an argument is not a DependencyMixin. + """ + + @abstractmethod + def set_downstream(self, nodes: DependencyType) -> "DependencyMixin": + """Set one or more downstream nodes for this node. + + Args: + nodes (DependencyType): Downstream nodes to be set to current node. + + Returns: + DependencyMixin: Returns self to allow method chaining. + + Raises: + ValueError: If no downstream nodes are provided or if an argument is not a DependencyMixin. + """ + + def __lshift__(self, nodes: DependencyType) -> DependencyType: + """Implements self << nodes + + Example: + + .. code-block:: python + + # means node.set_upstream(input_node) + node << input_node + + # means node2.set_upstream([input_node]) + node2 << [input_node] + """ + self.set_upstream(nodes) + return nodes + + def __rshift__(self, nodes: DependencyType) -> DependencyType: + """Implements self >> nodes + + Example: + + .. code-block:: python + + # means node.set_downstream(next_node) + node >> next_node + + # means node2.set_downstream([next_node]) + node2 >> [next_node] + + """ + self.set_downstream(nodes) + return nodes + + def __rrshift__(self, nodes: DependencyType) -> "DependencyMixin": + """Implements [node] >> self""" + self.__lshift__(nodes) + return self + + def __rlshift__(self, nodes: DependencyType) -> "DependencyMixin": + """Implements [node] << self""" + self.__rshift__(nodes) + return self + + +class DAGVar: + _thread_local = threading.local() + _async_local = contextvars.ContextVar("current_dag_stack", default=deque()) + + @classmethod + def enter_dag(cls, dag) -> None: + is_async = _is_async_context() + if is_async: + stack = cls._async_local.get() + stack.append(dag) + cls._async_local.set(stack) + else: + if not hasattr(cls._thread_local, "current_dag_stack"): + cls._thread_local.current_dag_stack = deque() + cls._thread_local.current_dag_stack.append(dag) + + @classmethod + def exit_dag(cls) -> None: + is_async = _is_async_context() + if is_async: + stack = cls._async_local.get() + if stack: + stack.pop() + cls._async_local.set(stack) + else: + if ( + hasattr(cls._thread_local, "current_dag_stack") + and cls._thread_local.current_dag_stack + ): + cls._thread_local.current_dag_stack.pop() + + @classmethod + def get_current_dag(cls) -> Optional["DAG"]: + is_async = _is_async_context() + if is_async: + stack = cls._async_local.get() + return stack[-1] if stack else None + else: + if ( + hasattr(cls._thread_local, "current_dag_stack") + and cls._thread_local.current_dag_stack + ): + return cls._thread_local.current_dag_stack[-1] + return None + + +class DAGNode(DependencyMixin, ABC): + resource_group: Optional[ResourceGroup] = None + """The resource group of current DAGNode""" + + def __init__( + self, dag: Optional["DAG"] = None, node_id: str = None, node_name: str = None + ) -> None: + super().__init__() + self._upstream: List["DAGNode"] = [] + self._downstream: List["DAGNode"] = [] + self._dag: Optional["DAG"] = dag or DAGVar.get_current_dag() + if not node_id and self._dag: + node_id = self._dag._new_node_id() + self._node_id: str = node_id + self._node_name: str = node_name + + @property + def node_id(self) -> str: + return self._node_id + + def set_node_id(self, node_id: str) -> None: + self._node_id = node_id + + def __hash__(self) -> int: + if self.node_id: + return hash(self.node_id) + else: + return super().__hash__() + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, DAGNode): + return False + return self.node_id == other.node_id 
+
+    @property
+    def node_name(self) -> str:
+        return self._node_name
+
+    @property
+    def dag(self) -> Optional["DAG"]:
+        return self._dag
+
+    def set_upstream(self, nodes: DependencyType) -> "DAGNode":
+        self.set_dependency(nodes)
+        return self
+
+    def set_downstream(self, nodes: DependencyType) -> "DAGNode":
+        self.set_dependency(nodes, is_upstream=False)
+        return self
+
+    @property
+    def upstream(self) -> List["DAGNode"]:
+        return self._upstream
+
+    @property
+    def downstream(self) -> List["DAGNode"]:
+        return self._downstream
+
+    def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> None:
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        if not all(isinstance(node, DAGNode) for node in nodes):
+            raise ValueError(
+                "all nodes in a dependency must be instances of 'DAGNode'"
+            )
+        nodes: Sequence[DAGNode] = nodes
+        dags = set([node.dag for node in nodes if node.dag])
+        if self.dag:
+            dags.add(self.dag)
+        if not dags:
+            raise ValueError("dependencies can only be set within a DAG context")
+        if len(dags) != 1:
+            raise ValueError(
+                "dependencies can only be set within a single DAG context"
+            )
+        dag = dags.pop()
+        self._dag = dag
+
+        dag._append_node(self)
+        for node in nodes:
+            if is_upstream:
+                if node not in self.upstream:
+                    node._dag = dag
+                    dag._append_node(node)
+
+                    self._upstream.append(node)
+                    node._downstream.append(self)
+            elif node not in self._downstream:
+                node._dag = dag
+                dag._append_node(node)
+
+                self._downstream.append(node)
+                node._upstream.append(self)
+
+
+class DAGContext:
+    def __init__(self) -> None:
+        self._curr_task_ctx = None
+        self._share_data: Dict[str, Any] = {}
+
+    @property
+    def current_task_context(self) -> TaskContext:
+        return self._curr_task_ctx
+
+    def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
+        self._curr_task_ctx = _curr_task_ctx
+
+    async def get_share_data(self, key: str) -> Any:
+        return self._share_data.get(key)
+
+    async def save_to_share_data(self, key: str, data: Any) -> None:
+        self._share_data[key] = data
+
+
+class DAG:
+    def __init__(
+        self, dag_id: str, resource_group: Optional[ResourceGroup] = None
+    ) -> None:
+        self._dag_id = dag_id
+        self._resource_group = resource_group
+        self.node_map: Dict[str, DAGNode] = {}
+
+    def _append_node(self, node: DAGNode) -> None:
+        self.node_map[node.node_id] = node
+
+    def _new_node_id(self) -> str:
+        return str(uuid.uuid4())
+
+    def __enter__(self):
+        DAGVar.enter_dag(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        DAGVar.exit_dag()
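The shift operators above are the primary wiring API. Here is a minimal sketch of how they combine with the `DAG` context manager, using the `InputOperator`, `MapOperator`, and `DefaultWorkflowRunner` introduced later in this diff:

.. code-block:: python

    import asyncio

    from pilot.awel import (
        DAG,
        InputOperator,
        MapOperator,
        SimpleInputSource,
        DefaultWorkflowRunner,
        initialize_awel,
    )

    initialize_awel(DefaultWorkflowRunner())

    # Nodes created inside the `with` block pick up the current DAG from
    # DAGVar, so no explicit `dag=` argument is needed.
    with DAG("double_dag"):
        input_node = InputOperator(input_source=SimpleInputSource(21))
        double_node = MapOperator(map_function=lambda x: x * 2)
        # Equivalent to input_node.set_downstream(double_node)
        input_node >> double_node

    assert asyncio.run(double_node.call()) == 42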
diff --git a/pilot/awel/dag/tests/__init__.py b/pilot/awel/dag/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/dag/tests/test_dag.py b/pilot/awel/dag/tests/test_dag.py
new file mode 100644
index 000000000..c30530dc8
--- /dev/null
+++ b/pilot/awel/dag/tests/test_dag.py
@@ -0,0 +1,51 @@
+import pytest
+import threading
+import asyncio
+from ..base import DAG, DAGVar
+
+
+def test_dag_context_sync():
+    dag1 = DAG("dag1")
+    dag2 = DAG("dag2")
+
+    with dag1:
+        assert DAGVar.get_current_dag() == dag1
+        with dag2:
+            assert DAGVar.get_current_dag() == dag2
+        assert DAGVar.get_current_dag() == dag1
+    assert DAGVar.get_current_dag() is None
+
+
+def test_dag_context_threading():
+    def thread_function(dag):
+        DAGVar.enter_dag(dag)
+        assert DAGVar.get_current_dag() == dag
+        DAGVar.exit_dag()
+
+    dag1 = DAG("dag1")
+    dag2 = DAG("dag2")
+
+    thread1 = threading.Thread(target=thread_function, args=(dag1,))
+    thread2 = threading.Thread(target=thread_function, args=(dag2,))
+
+    thread1.start()
+    thread2.start()
+    thread1.join()
+    thread2.join()
+
+    assert DAGVar.get_current_dag() is None
+
+
+@pytest.mark.asyncio
+async def test_dag_context_async():
+    async def async_function(dag):
+        DAGVar.enter_dag(dag)
+        assert DAGVar.get_current_dag() == dag
+        DAGVar.exit_dag()
+
+    dag1 = DAG("dag1")
+    dag2 = DAG("dag2")
+
+    await asyncio.gather(async_function(dag1), async_function(dag2))
+
+    assert DAGVar.get_current_dag() is None
diff --git a/pilot/awel/operator/__init__.py b/pilot/awel/operator/__init__.py
new file mode 100644
index 000000000..e69de29bb
+ """ + + def __init__( + self, + task_id: Optional[str] = None, + task_name: Optional[str] = None, + dag: Optional[DAG] = None, + runner: WorkflowRunner = None, + **kwargs, + ) -> None: + """Initializes a BaseOperator with an optional workflow runner. + + Args: + runner (WorkflowRunner, optional): The runner used to execute the workflow. Defaults to None. + """ + super().__init__(node_id=task_id, node_name=task_name, dag=dag, **kwargs) + if not runner: + from pilot.awel import DefaultWorkflowRunner + + runner = DefaultWorkflowRunner() + + self._runner: WorkflowRunner = runner + self._dag_ctx: DAGContext = None + + @property + def current_dag_context(self) -> DAGContext: + return self._dag_ctx + + async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + if not self.node_id: + raise ValueError(f"The DAG Node ID can't be empty, current node {self}") + self._dag_ctx = dag_ctx + return await self._do_run(dag_ctx) + + @abstractmethod + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + """ + Abstract method to run the task within the DAG node. + + Args: + dag_ctx (DAGContext): The context of the DAG when this node is run. + + Returns: + TaskOutput[OUT]: The task output after this node has been run. + """ + + async def call(self, call_data: Optional[CALL_DATA] = None) -> OUT: + """Execute the node and return the output. + + This method is a high-level wrapper for executing the node. + + Args: + call_data (CALL_DATA): The data pass to root operator node. + + Returns: + OUT: The output of the node after execution. + """ + out_ctx = await self._runner.execute_workflow(self, call_data) + return out_ctx.current_task_context.task_output.output + + async def call_stream( + self, call_data: Optional[CALL_DATA] = None + ) -> AsyncIterator[OUT]: + """Execute the node and return the output as a stream. + + This method is used for nodes where the output is a stream. + + Args: + call_data (CALL_DATA): The data pass to root operator node. + + Returns: + AsyncIterator[OUT]: An asynchronous iterator over the output stream. + """ + out_ctx = await self._runner.execute_workflow(self, call_data) + return out_ctx.current_task_context.task_output.output_stream + + +def initialize_awel(runner: WorkflowRunner): + global default_runner + default_runner = runner diff --git a/pilot/awel/operator/common_operator.py b/pilot/awel/operator/common_operator.py new file mode 100644 index 000000000..6d12565aa --- /dev/null +++ b/pilot/awel/operator/common_operator.py @@ -0,0 +1,239 @@ +from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable +import asyncio +import logging + +from ..dag.base import DAGContext +from ..task.base import ( + TaskContext, + TaskOutput, + IN, + OUT, + InputContext, + InputSource, +) + +from .base import BaseOperator + + +logger = logging.getLogger(__name__) + + +class JoinOperator(BaseOperator, Generic[OUT]): + """Operator that joins inputs using a custom combine function. + + This node type is useful for combining the outputs of upstream nodes. + """ + + def __init__(self, combine_function, **kwargs): + super().__init__(**kwargs) + if not callable(combine_function): + raise ValueError("combine_function must be callable") + self.combine_function = combine_function + + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + """Run the join operation on the DAG context's inputs. + Args: + dag_ctx (DAGContext): The current context of the DAG. + + Returns: + TaskOutput[OUT]: The task output after this node has been run. 
+ """ + curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context + input_ctx: InputContext = await curr_task_ctx.task_input.map_all( + self.combine_function + ) + # All join result store in the first parent output + join_output = input_ctx.parent_outputs[0].task_output + curr_task_ctx.set_task_output(join_output) + return join_output + + +class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]): + def __init__(self, reduce_function=None, **kwargs): + """Initializes a ReduceStreamOperator with a combine function. + + Args: + combine_function: A function that defines how to combine inputs. + + Raises: + ValueError: If the combine_function is not callable. + """ + super().__init__(**kwargs) + if reduce_function and not callable(reduce_function): + raise ValueError("reduce_function must be callable") + self.reduce_function = reduce_function + + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + """Run the join operation on the DAG context's inputs. + + Args: + dag_ctx (DAGContext): The current context of the DAG. + + Returns: + TaskOutput[OUT]: The task output after this node has been run. + """ + curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context + task_input = curr_task_ctx.task_input + if not task_input.check_stream(): + raise ValueError("ReduceStreamOperator expects stream data") + if not task_input.check_single_parent(): + raise ValueError("ReduceStreamOperator expects single parent") + + reduce_function = self.reduce_function or self.reduce + + input_ctx: InputContext = await task_input.reduce(reduce_function) + # All join result store in the first parent output + reduce_output = input_ctx.parent_outputs[0].task_output + curr_task_ctx.set_task_output(reduce_output) + return reduce_output + + async def reduce(self, input_value: AsyncIterator[IN]) -> OUT: + raise NotImplementedError + + +class MapOperator(BaseOperator, Generic[IN, OUT]): + """Map operator that applies a mapping function to its inputs. + + This operator transforms its input data using a provided mapping function and + passes the transformed data downstream. + """ + + def __init__(self, map_function=None, **kwargs): + """Initializes a MapDAGNode with a mapping function. + + Args: + map_function: A function that defines how to map the input data. + + Raises: + ValueError: If the map_function is not callable. + """ + super().__init__(**kwargs) + if map_function and not callable(map_function): + raise ValueError("map_function must be callable") + self.map_function = map_function + + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + """Run the mapping operation on the DAG context's inputs. + + This method applies the mapping function to the input context and updates + the DAG context with the new data. + + Args: + dag_ctx (DAGContext[IN]): The current context of the DAG. + + Returns: + TaskOutput[OUT]: The task output after this node has been run. 
+
+
+class MapOperator(BaseOperator, Generic[IN, OUT]):
+    """Map operator that applies a mapping function to its inputs.
+
+    This operator transforms its input data using a provided mapping function and
+    passes the transformed data downstream.
+    """
+
+    def __init__(self, map_function=None, **kwargs):
+        """Initializes a MapOperator with a mapping function.
+
+        Args:
+            map_function: A function that defines how to map the input data.
+
+        Raises:
+            ValueError: If the map_function is not callable.
+        """
+        super().__init__(**kwargs)
+        if map_function and not callable(map_function):
+            raise ValueError("map_function must be callable")
+        self.map_function = map_function
+
+    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+        """Run the mapping operation on the DAG context's inputs.
+
+        This method applies the mapping function to the input context and updates
+        the DAG context with the new data.
+
+        Args:
+            dag_ctx (DAGContext[IN]): The current context of the DAG.
+
+        Returns:
+            TaskOutput[OUT]: The task output after this node has been run.
+
+        Raises:
+            ValueError: If there is not exactly one parent or the map_function is not callable.
+        """
+        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+        if not curr_task_ctx.task_input.check_single_parent():
+            num_parents = len(curr_task_ctx.task_input.parent_outputs)
+            raise ValueError(
+                f"task {curr_task_ctx.task_id} MapOperator expects a single parent, found {num_parents} parents"
+            )
+        map_function = self.map_function or self.map
+
+        input_ctx: InputContext = await curr_task_ctx.task_input.map(map_function)
+        # The map result is stored in the first parent output
+        map_output = input_ctx.parent_outputs[0].task_output
+        curr_task_ctx.set_task_output(map_output)
+        return map_output
+
+    async def map(self, input_value: IN) -> OUT:
+        raise NotImplementedError
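Either form works: pass `map_function` for one-off transforms, or subclass and override `map` when the operator carries state or is reused. A minimal sketch:

.. code-block:: python

    from pilot.awel import MapOperator


    class PrefixOperator(MapOperator[str, str]):
        """Prepends a fixed prefix; overriding `map` replaces `map_function`."""

        def __init__(self, prefix: str, **kwargs):
            super().__init__(**kwargs)
            self._prefix = prefix

        async def map(self, input_value: str) -> str:
            return self._prefix + input_value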
+ """ + curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context + task_input = curr_task_ctx.task_input + if task_input.check_stream(): + raise ValueError("BranchDAGNode expects no stream data") + if not task_input.check_single_parent(): + raise ValueError("BranchDAGNode expects single parent") + + branches = self._branches + if not branches: + branches = await self.branchs() + + branch_func_tasks = [] + branch_nodes: List[str] = [] + for func, node_name in branches.items(): + branch_nodes.append(node_name) + branch_func_tasks.append( + curr_task_ctx.task_input.predicate_map(func, failed_value=None) + ) + + branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks) + parent_output = task_input.parent_outputs[0].task_output + curr_task_ctx.set_task_output(parent_output) + skip_node_names = [] + for i, ctx in enumerate(branch_input_ctxs): + node_name = branch_nodes[i] + branch_out = ctx.parent_outputs[0].task_output + logger.info( + f"branch_input_ctxs {i} result {branch_out.output}, is_empty: {branch_out.is_empty}" + ) + if ctx.parent_outputs[0].task_output.is_empty: + logger.info(f"Skip node name {node_name}") + skip_node_names.append(node_name) + curr_task_ctx.update_metadata("skip_node_names", skip_node_names) + return parent_output + + async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]: + raise NotImplementedError + + +class InputOperator(BaseOperator, Generic[OUT]): + def __init__(self, input_source: InputSource[OUT], **kwargs) -> None: + super().__init__(**kwargs) + self._input_source = input_source + + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context + task_output = await self._input_source.read(curr_task_ctx) + curr_task_ctx.set_task_output(task_output) + return task_output diff --git a/pilot/awel/operator/stream_operator.py b/pilot/awel/operator/stream_operator.py new file mode 100644 index 000000000..7de916a83 --- /dev/null +++ b/pilot/awel/operator/stream_operator.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod +from typing import Generic, AsyncIterator +from ..task.base import OUT, IN, TaskOutput, TaskContext +from ..dag.base import DAGContext +from .base import BaseOperator + + +class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]): + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: + curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context + output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify( + self.streamify + ) + curr_task_ctx.set_task_output(output) + return output + + @abstractmethod + async def streamify(self, input_value: IN) -> AsyncIterator[OUT]: + """Convert a value of IN to an AsyncIterator[OUT] + + Args: + input_value (IN): The data of parent operator's output + + Example: + + .. 
diff --git a/pilot/awel/operator/stream_operator.py b/pilot/awel/operator/stream_operator.py
new file mode 100644
index 000000000..7de916a83
--- /dev/null
+++ b/pilot/awel/operator/stream_operator.py
@@ -0,0 +1,90 @@
+from abc import ABC, abstractmethod
+from typing import Generic, AsyncIterator
+from ..task.base import OUT, IN, TaskOutput, TaskContext
+from ..dag.base import DAGContext
+from .base import BaseOperator
+
+
+class StreamifyAbsOperator(BaseOperator[OUT], ABC, Generic[IN, OUT]):
+    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+        output = await curr_task_ctx.task_input.parent_outputs[0].task_output.streamify(
+            self.streamify
+        )
+        curr_task_ctx.set_task_output(output)
+        return output
+
+    @abstractmethod
+    async def streamify(self, input_value: IN) -> AsyncIterator[OUT]:
+        """Convert a value of IN to an AsyncIterator[OUT]
+
+        Args:
+            input_value (IN): The data of the parent operator's output
+
+        Example:
+
+        .. code-block:: python
+
+            class MyStreamOperator(StreamifyAbsOperator[int, int]):
+                async def streamify(self, input_value: int) -> AsyncIterator[int]:
+                    for i in range(input_value):
+                        yield i
+        """
+
+
+class UnstreamifyAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
+    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+        output = await curr_task_ctx.task_input.parent_outputs[
+            0
+        ].task_output.unstreamify(self.unstreamify)
+        curr_task_ctx.set_task_output(output)
+        return output
+
+    @abstractmethod
+    async def unstreamify(self, input_value: AsyncIterator[IN]) -> OUT:
+        """Convert a value of AsyncIterator[IN] to an OUT.
+
+        Args:
+            input_value (AsyncIterator[IN]): The data of the parent operator's output
+
+        Example:
+
+        .. code-block:: python
+
+            class MyUnstreamOperator(UnstreamifyAbsOperator[int, int]):
+                async def unstreamify(self, input_value: AsyncIterator[int]) -> int:
+                    value_cnt = 0
+                    async for v in input_value:
+                        value_cnt += 1
+                    return value_cnt
+        """
+
+
+class TransformStreamAbsOperator(BaseOperator[OUT], Generic[IN, OUT]):
+    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
+        curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context
+        output = await curr_task_ctx.task_input.parent_outputs[
+            0
+        ].task_output.transform_stream(self.transform_stream)
+        curr_task_ctx.set_task_output(output)
+        return output
+
+    @abstractmethod
+    async def transform_stream(
+        self, input_value: AsyncIterator[IN]
+    ) -> AsyncIterator[OUT]:
+        """Transform an AsyncIterator[IN] to another AsyncIterator[OUT] using a given function.
+
+        Args:
+            input_value (AsyncIterator[IN]): The data of the parent operator's output
+
+        Example:
+
+        .. code-block:: python
+
+            class MyTransformStreamOperator(TransformStreamAbsOperator[int, int]):
+                async def transform_stream(
+                    self, input_value: AsyncIterator[int]
+                ) -> AsyncIterator[int]:
+                    async for v in input_value:
+                        yield v + 1
+        """
diff --git a/pilot/awel/resource/__init__.py b/pilot/awel/resource/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/resource/base.py b/pilot/awel/resource/base.py
new file mode 100644
index 000000000..97fefbbc3
--- /dev/null
+++ b/pilot/awel/resource/base.py
@@ -0,0 +1,8 @@
+from abc import ABC, abstractmethod
+
+
+class ResourceGroup(ABC):
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """The name of current resource group"""
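The three stream operator bases compose naturally. A minimal sketch chaining a transform into an unstreamify step:

.. code-block:: python

    from typing import AsyncIterator

    from pilot.awel import (
        DAG,
        InputOperator,
        SimpleInputSource,
        TransformStreamAbsOperator,
        UnstreamifyAbsOperator,
    )


    class UppercaseStream(TransformStreamAbsOperator[str, str]):
        async def transform_stream(
            self, input_value: AsyncIterator[str]
        ) -> AsyncIterator[str]:
            async for token in input_value:
                yield token.upper()


    class JoinStream(UnstreamifyAbsOperator[str, str]):
        async def unstreamify(self, input_value: AsyncIterator[str]) -> str:
            return " ".join([token async for token in input_value])


    async def tokens() -> AsyncIterator[str]:
        for t in ["hello", "awel"]:
            yield t


    with DAG("transform_dag"):
        source = InputOperator(SimpleInputSource(tokens()))
        source >> UppercaseStream() >> (collect := JoinStream())

    # await collect.call() == "HELLO AWEL"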
diff --git a/pilot/awel/runner/__init__.py b/pilot/awel/runner/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/runner/job_manager.py b/pilot/awel/runner/job_manager.py
new file mode 100644
index 000000000..7a1d12ead
--- /dev/null
+++ b/pilot/awel/runner/job_manager.py
@@ -0,0 +1,82 @@
+from typing import List, Set, Optional, Dict
+import uuid
+import logging
+from ..dag.base import DAG
+
+from ..operator.base import BaseOperator, CALL_DATA
+
+logger = logging.getLogger(__name__)
+
+
+class DAGNodeInstance:
+    def __init__(self, node_instance: DAG) -> None:
+        pass
+
+
+class DAGInstance:
+    def __init__(self, dag: DAG) -> None:
+        self._dag = dag
+
+
+class JobManager:
+    def __init__(
+        self,
+        root_nodes: List[BaseOperator],
+        all_nodes: List[BaseOperator],
+        end_node: BaseOperator,
+        id2call_data: Dict[str, Dict],
+    ) -> None:
+        self._root_nodes = root_nodes
+        self._all_nodes = all_nodes
+        self._end_node = end_node
+        self._id2node_data = id2call_data
+
+    @staticmethod
+    def build_from_end_node(
+        end_node: BaseOperator, call_data: Optional[CALL_DATA] = None
+    ) -> "JobManager":
+        nodes = _build_from_end_node(end_node)
+        root_nodes = _get_root_nodes(nodes)
+        id2call_data = _save_call_data(root_nodes, call_data)
+        return JobManager(root_nodes, nodes, end_node, id2call_data)
+
+    def get_call_data_by_id(self, node_id: str) -> Optional[Dict]:
+        return self._id2node_data.get(node_id)
+
+
+def _save_call_data(
+    root_nodes: List[BaseOperator], call_data: CALL_DATA
+) -> Dict[str, Dict]:
+    id2call_data = {}
+    logger.debug(f"_save_call_data: {call_data}, root_nodes: {root_nodes}")
+    if not call_data:
+        return id2call_data
+    if len(root_nodes) == 1:
+        node = root_nodes[0]
+        logger.info(f"Save call data to node {node.node_id}, call_data: {call_data}")
+        id2call_data[node.node_id] = call_data
+    else:
+        for node in root_nodes:
+            node_id = node.node_id
+            logger.info(
+                f"Save call data to node {node.node_id}, call_data: {call_data.get(node_id)}"
+            )
+            id2call_data[node_id] = call_data.get(node_id)
+    return id2call_data
+
+
+def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]:
+    nodes = []
+    if isinstance(end_node, BaseOperator):
+        task_id = end_node.node_id
+        if not task_id:
+            task_id = str(uuid.uuid4())
+            end_node.set_node_id(task_id)
+        nodes.append(end_node)
+    for node in end_node.upstream:
+        nodes += _build_from_end_node(node)
+    return nodes
+
+
+def _get_root_nodes(nodes: List[BaseOperator]) -> List[BaseOperator]:
+    return list(set(filter(lambda x: not x.upstream, nodes)))
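`_save_call_data` implies two call-data shapes, depending on the number of root nodes. A hedged sketch (the node ids in the multi-root case are hypothetical placeholders, and the `"data"` key is what `SimpleCallDataInputSource`, later in this diff, reads):

.. code-block:: python

    from pilot.awel import BaseOperator


    async def run(end_node: BaseOperator):
        # Single root: JobManager binds the whole dict to that root's node_id.
        single = await end_node.call(call_data={"data": 42})

        # Multiple roots: key the outer dict by each root's node_id.
        multi = await end_node.call(
            call_data={
                "root-node-id-1": {"data": 1},
                "root-node-id-2": {"data": 2},
            }
        )
        return single, multi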
diff --git a/pilot/awel/runner/local_runner.py b/pilot/awel/runner/local_runner.py
new file mode 100644
index 000000000..769223212
--- /dev/null
+++ b/pilot/awel/runner/local_runner.py
@@ -0,0 +1,106 @@
+from typing import Dict, Optional, Set, List
+import logging
+
+from ..dag.base import DAGContext
+from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA
+from ..operator.common_operator import BranchOperator, JoinOperator
+from ..task.base import TaskContext, TaskState
+from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
+from .job_manager import JobManager
+
+logger = logging.getLogger(__name__)
+
+
+class DefaultWorkflowRunner(WorkflowRunner):
+    async def execute_workflow(
+        self, node: BaseOperator, call_data: Optional[CALL_DATA] = None
+    ) -> DAGContext:
+        # Create the DAG context
+        dag_ctx = DAGContext()
+        job_manager = JobManager.build_from_end_node(node, call_data)
+        logger.info(
+            f"Begin run workflow from end operator, id: {node.node_id}, call_data: {call_data}"
+        )
+        dag = node.dag
+        # Save node outputs
+        node_outputs: Dict[str, TaskContext] = {}
+        skip_node_ids = set()
+        await self._execute_node(
+            job_manager, node, dag_ctx, node_outputs, skip_node_ids
+        )
+
+        return dag_ctx
+
+    async def _execute_node(
+        self,
+        job_manager: JobManager,
+        node: BaseOperator,
+        dag_ctx: DAGContext,
+        node_outputs: Dict[str, TaskContext],
+        skip_node_ids: Set[str],
+    ):
+        # Skip nodes that have already run
+        if node.node_id in node_outputs:
+            return
+
+        # Run all upstream nodes first
+        for upstream_node in node.upstream:
+            if isinstance(upstream_node, BaseOperator):
+                await self._execute_node(
+                    job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids
+                )
+
+        inputs = [
+            node_outputs[upstream_node.node_id] for upstream_node in node.upstream
+        ]
+        input_ctx = DefaultInputContext(inputs)
+        task_ctx = DefaultTaskContext(node.node_id, TaskState.INIT, task_output=None)
+        task_ctx.set_call_data(job_manager.get_call_data_by_id(node.node_id))
+
+        task_ctx.set_task_input(input_ctx)
+        dag_ctx.set_current_task_context(task_ctx)
+        task_ctx.set_current_state(TaskState.RUNNING)
+
+        if node.node_id in skip_node_ids:
+            task_ctx.set_current_state(TaskState.SKIP)
+            task_ctx.set_task_output(SimpleTaskOutput(None))
+            node_outputs[node.node_id] = task_ctx
+            return
+        try:
+            logger.info(
+                f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
+            )
+            await node._run(dag_ctx)
+            node_outputs[node.node_id] = dag_ctx.current_task_context
+            task_ctx.set_current_state(TaskState.SUCCESS)
+
+            if isinstance(node, BranchOperator):
+                skip_nodes = task_ctx.metadata.get("skip_node_names", [])
+                logger.info(
+                    f"Current is branch operator, skip node names: {skip_nodes}"
+                )
+                _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
+        except Exception as e:
+            logger.error(f"Run operator {node.node_id} error, error message: {str(e)}")
+            task_ctx.set_current_state(TaskState.FAILED)
+            raise
+
+
+def _skip_current_downstream_by_node_name(
+    branch_node: BranchOperator, skip_nodes: List[str], skip_node_ids: Set[str]
+):
+    if not skip_nodes:
+        return
+    for child in branch_node.downstream:
+        if child.node_name in skip_nodes:
+            logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
+            _skip_downstream_by_id(child, skip_node_ids)
+
+
+def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
+    if isinstance(node, JoinOperator):
+        # Do not skip join nodes
+        return
+    skip_node_ids.add(node.node_id)
+    for child in node.downstream:
+        _skip_downstream_by_id(child, skip_node_ids)
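A hedged sketch of driving the runner directly (rather than through `BaseOperator.call`), to show what the returned `DAGContext` carries; this mirrors the tests later in this diff:

.. code-block:: python

    import asyncio

    from pilot.awel import (
        DAG,
        DefaultWorkflowRunner,
        InputOperator,
        MapOperator,
        SimpleInputSource,
        TaskState,
    )


    async def main():
        with DAG("runner_dag"):
            source = InputOperator(SimpleInputSource("hello"))
            upper = MapOperator(lambda s: s.upper())
            source >> upper

        runner = DefaultWorkflowRunner()
        # execute_workflow returns the DAGContext; current_task_context is the
        # context of the end operator we passed in.
        dag_ctx = await runner.execute_workflow(upper)
        task_ctx = dag_ctx.current_task_context
        assert task_ctx.current_state == TaskState.SUCCESS
        assert task_ctx.task_output.output == "HELLO"


    asyncio.run(main())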
diff --git a/pilot/awel/task/__init__.py b/pilot/awel/task/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/task/base.py b/pilot/awel/task/base.py
new file mode 100644
index 000000000..88b0df343
--- /dev/null
+++ b/pilot/awel/task/base.py
@@ -0,0 +1,367 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import (
+    TypeVar,
+    Generic,
+    Optional,
+    AsyncIterator,
+    Union,
+    Callable,
+    Any,
+    Dict,
+    List,
+)
+
+IN = TypeVar("IN")
+OUT = TypeVar("OUT")
+T = TypeVar("T")
+
+
+class TaskState(str, Enum):
+    """Enumeration representing the state of a task in the workflow.
+
+    This Enum defines the states a task can be in during its lifecycle in the DAG.
+    """
+
+    INIT = "init"  # Initial state of the task, not yet started
+    SKIP = "skip"  # State indicating the task was skipped
+    RUNNING = "running"  # State indicating the task is currently running
+    SUCCESS = "success"  # State indicating the task completed successfully
+    FAILED = "failed"  # State indicating the task failed during execution
+
+
+class TaskOutput(ABC, Generic[T]):
+    """Abstract base class representing the output of a task.
+
+    This class encapsulates the output of a task and provides methods to access the output data.
+    It can be subclassed to implement specific output behaviors.
+    """
+
+    @property
+    def is_stream(self) -> bool:
+        """Check if the output is a stream.
+
+        Returns:
+            bool: True if the output is a stream, False otherwise.
+        """
+        return False
+
+    @property
+    def is_empty(self) -> bool:
+        """Check if the output is empty.
+
+        Returns:
+            bool: True if the output is empty, False otherwise.
+        """
+        return False
+
+    @property
+    def output(self) -> Optional[T]:
+        """Return the output of the task.
+
+        Returns:
+            T: The output of the task. None if the output is empty.
+        """
+        raise NotImplementedError
+
+    @property
+    def output_stream(self) -> Optional[AsyncIterator[T]]:
+        """Return the output of the task as an asynchronous stream.
+
+        Returns:
+            AsyncIterator[T]: An asynchronous iterator over the output. None if the output is empty.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
+        """Set the output data on the current object.
+
+        Args:
+            output_data (Union[T, AsyncIterator[T]]): Output data.
+        """
+
+    @abstractmethod
+    def new_output(self) -> "TaskOutput[T]":
+        """Create a new output object"""
+
+    async def map(self, map_func) -> "TaskOutput[T]":
+        """Apply a mapping function to the task's output.
+
+        Args:
+            map_func: A function to apply to the task's output.
+
+        Returns:
+            TaskOutput[T]: The result of applying the mapping function.
+        """
+        raise NotImplementedError
+
+    async def reduce(self, reduce_func) -> "TaskOutput[T]":
+        """Apply a reducing function to the task's output.
+
+        Converts a stream TaskOutput into a non-stream TaskOutput.
+
+        Args:
+            reduce_func: A reducing function to apply to the task's output.
+
+        Returns:
+            TaskOutput[T]: The result of applying the reducing function.
+        """
+        raise NotImplementedError
+
+    async def streamify(
+        self, transform_func: Callable[[T], AsyncIterator[T]]
+    ) -> "TaskOutput[T]":
+        """Convert a value of type T to an AsyncIterator[T] using a transform function.
+
+        Args:
+            transform_func (Callable[[T], AsyncIterator[T]]): Function to transform a T value into an AsyncIterator[T].
+
+        Returns:
+            TaskOutput[T]: The result of applying the transform function.
+        """
+        raise NotImplementedError
+
+    async def transform_stream(
+        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
+    ) -> "TaskOutput[T]":
+        """Transform an AsyncIterator[T] to another AsyncIterator[T] using a given function.
+
+        Args:
+            transform_func (Callable[[AsyncIterator[T]], AsyncIterator[T]]): Function to apply to the AsyncIterator[T].
+
+        Returns:
+            TaskOutput[T]: The result of applying the transform function.
+        """
+        raise NotImplementedError
+
+    async def unstreamify(
+        self, transform_func: Callable[[AsyncIterator[T]], T]
+    ) -> "TaskOutput[T]":
+        """Convert an AsyncIterator[T] to a value of type T using a transform function.
+
+        Args:
+            transform_func (Callable[[AsyncIterator[T]], T]): Function to transform an AsyncIterator[T] into a T value.
+
+        Returns:
+            TaskOutput[T]: The result of applying the transform function.
+        """
+        raise NotImplementedError
+
+    async def check_condition(self, condition_func) -> bool:
+        """Check if the current output meets a given condition.
+
+        Args:
+            condition_func: A function to determine if the condition is met.
+
+        Returns:
+            bool: True if the current output meets the condition, False otherwise.
+        """
+        raise NotImplementedError
+
+
+class TaskContext(ABC, Generic[T]):
+    """Abstract base class representing the context of a task within a DAG.
+
+    This class provides the interface for accessing task-related information
+    and manipulating task output.
+    """
+
+    @property
+    @abstractmethod
+    def task_id(self) -> str:
+        """Return the unique identifier of the task.
+
+        Returns:
+            str: The unique identifier of the task.
+        """
+
+    @property
+    @abstractmethod
+    def task_input(self) -> "InputContext":
+        """Return the InputContext of the current task.
+
+        Returns:
+            InputContext: The InputContext of the current task.
+        """
+
+    @abstractmethod
+    def set_task_input(self, input_ctx: "InputContext") -> None:
+        """Set the InputContext object for the current task.
+
+        Args:
+            input_ctx (InputContext): The InputContext of the current task.
+        """
+
+    @property
+    @abstractmethod
+    def task_output(self) -> TaskOutput[T]:
+        """Return the output object of the task.
+
+        Returns:
+            TaskOutput[T]: The output object of the task.
+        """
+
+    @abstractmethod
+    def set_task_output(self, task_output: TaskOutput[T]) -> None:
+        """Set the output object for the current task."""
+
+    @property
+    @abstractmethod
+    def current_state(self) -> TaskState:
+        """Get the current state of the task.
+
+        Returns:
+            TaskState: The current state of the task.
+        """
+
+    @abstractmethod
+    def set_current_state(self, task_state: TaskState) -> None:
+        """Set the current task state.
+
+        Args:
+            task_state (TaskState): The task state to be set.
+        """
+
+    @abstractmethod
+    def new_ctx(self) -> "TaskContext":
+        """Create a new task context.
+
+        Returns:
+            TaskContext: A new instance of a TaskContext.
+        """
+
+    @property
+    @abstractmethod
+    def metadata(self) -> Dict[str, Any]:
+        """Get the metadata of the current task.
+
+        Returns:
+            Dict[str, Any]: The metadata.
+        """
+
+    def update_metadata(self, key: str, value: Any) -> None:
+        """Update the metadata with the given key and value.
+
+        Args:
+            key (str): The metadata key.
+            value (Any): The value to add to the metadata.
+        """
+        self.metadata[key] = value
+
+    @property
+    def call_data(self) -> Optional[Dict]:
+        """Get the call data for the current task."""
+        return self.metadata.get("call_data")
+
+    def set_call_data(self, call_data: Dict) -> None:
+        """Set the call data for the current task."""
+        self.update_metadata("call_data", call_data)
+
+
+class InputContext(ABC):
+    """Abstract base class representing the context of inputs to an operator node.
+
+    This class defines methods to manipulate and access the inputs of an operator node.
+    """
+
+    @property
+    @abstractmethod
+    def parent_outputs(self) -> List[TaskContext]:
+        """Get the outputs from the parent nodes.
+
+        Returns:
+            List[TaskContext]: A list of contexts of the parent nodes' outputs.
+        """
+
+    @abstractmethod
+    async def map(self, map_func: Callable[[Any], Any]) -> "InputContext":
+        """Apply a mapping function to the inputs.
+
+        Args:
+            map_func (Callable[[Any], Any]): A function to be applied to the inputs.
+
+        Returns:
+            InputContext: A new InputContext instance with the mapped inputs.
+        """
+
+    @abstractmethod
+    async def map_all(self, map_func: Callable[..., Any]) -> "InputContext":
+        """Apply a mapping function to all inputs at once.
+
+        Args:
+            map_func (Callable[..., Any]): A function to be applied to all inputs.
+
+        Returns:
+            InputContext: A new InputContext instance with the mapped inputs.
+        """
+
+    @abstractmethod
+    async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
+        """Apply a reducing function to the inputs.
+
+        Args:
+            reduce_func (Callable[[Any], Any]): A function that reduces the inputs.
+
+        Returns:
+            InputContext: A new InputContext instance with the reduced inputs.
+        """
+
+    @abstractmethod
+    async def filter(self, filter_func: Callable[[Any], bool]) -> "InputContext":
+        """Filter the inputs based on a provided function.
+
+        Args:
+            filter_func (Callable[[Any], bool]): A function that returns True for inputs to keep.
+
+        Returns:
+            InputContext: A new InputContext instance with the filtered inputs.
+        """
+
+    @abstractmethod
+    async def predicate_map(
+        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
+    ) -> "InputContext":
+        """Apply a predicate function to the inputs.
+
+        Args:
+            predicate_func (Callable[[Any], bool]): A predicate applied to each input.
+            failed_value (Any): The value to set for an input when the predicate returns False.
+
+        Returns:
+            InputContext: A new InputContext instance with the predicate results.
+        """
+
+    def check_single_parent(self) -> bool:
+        """Check if there is only a single parent output.
+
+        Returns:
+            bool: True if there is only one parent output, False otherwise.
+        """
+        return len(self.parent_outputs) == 1
+
+    def check_stream(self, skip_empty: bool = False) -> bool:
+        """Check if all parent outputs are streams.
+
+        Args:
+            skip_empty (bool): Whether to skip empty outputs.
+
+        Returns:
+            bool: True if all (non-skipped) parent outputs are streams, False otherwise.
+        """
+        for out in self.parent_outputs:
+            if out.task_output.is_empty and skip_empty:
+                continue
+            if not (out.task_output and out.task_output.is_stream):
+                return False
+        return True
+
+
+class InputSource(ABC, Generic[T]):
+    """Abstract base class representing the source of inputs to a DAG node."""
+
+    @abstractmethod
+    async def read(self, task_ctx: TaskContext) -> TaskOutput[T]:
+        """Read the data from the current input source.
+
+        Returns:
+            TaskOutput[T]: The output object read from the current source.
+        """
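A minimal sketch of a custom `InputSource`, using the `SimpleTaskOutput` implementation from `task_impl.py` below; unlike `SimpleInputSource`, it produces fresh data on every run:

.. code-block:: python

    import time

    from pilot.awel import InputSource, SimpleTaskOutput, TaskContext, TaskOutput


    class ClockInputSource(InputSource[float]):
        """Produces a fresh timestamp every time the workflow runs."""

        async def read(self, task_ctx: TaskContext) -> TaskOutput[float]:
            return SimpleTaskOutput(time.time())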
diff --git a/pilot/awel/task/task_impl.py b/pilot/awel/task/task_impl.py
new file mode 100644
index 000000000..f969c135c
--- /dev/null
+++ b/pilot/awel/task/task_impl.py
@@ -0,0 +1,339 @@
+from abc import ABC, abstractmethod
+from typing import (
+    Callable,
+    AsyncIterator,
+    List,
+    Generic,
+    Any,
+    Tuple,
+    Dict,
+    Union,
+)
+import asyncio
+import logging
+from .base import TaskOutput, TaskContext, TaskState, InputContext, InputSource, T
+
+
+logger = logging.getLogger(__name__)
+
+
+async def _reduce_stream(stream: AsyncIterator, reduce_function) -> Any:
+    # Initialize the accumulator with the first element of the stream
+    try:
+        accumulator = await stream.__anext__()
+    except StopAsyncIteration:
+        raise ValueError("Stream is empty")
+    is_async = asyncio.iscoroutinefunction(reduce_function)
+    async for element in stream:
+        if is_async:
+            accumulator = await reduce_function(accumulator, element)
+        else:
+            accumulator = reduce_function(accumulator, element)
+    return accumulator
+
+
+class SimpleTaskOutput(TaskOutput[T], Generic[T]):
+    def __init__(self, data: T) -> None:
+        super().__init__()
+        self._data = data
+
+    @property
+    def output(self) -> T:
+        return self._data
+
+    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
+        self._data = output_data
+
+    def new_output(self) -> TaskOutput[T]:
+        return SimpleTaskOutput(None)
+
+    @property
+    def is_empty(self) -> bool:
+        return not self._data
+
+    async def _apply_func(self, func) -> Any:
+        if asyncio.iscoroutinefunction(func):
+            out = await func(self._data)
+        else:
+            out = func(self._data)
+        return out
+
+    async def map(self, map_func) -> TaskOutput[T]:
+        out = await self._apply_func(map_func)
+        return SimpleTaskOutput(out)
+
+    async def check_condition(self, condition_func) -> bool:
+        return await self._apply_func(condition_func)
+
+    async def streamify(
+        self, transform_func: Callable[[T], AsyncIterator[T]]
+    ) -> TaskOutput[T]:
+        out = await self._apply_func(transform_func)
+        return SimpleStreamTaskOutput(out)
+
+
+class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
+    def __init__(self, data: AsyncIterator[T]) -> None:
+        super().__init__()
+        self._data = data
+
+    @property
+    def is_stream(self) -> bool:
+        return True
+
+    @property
+    def is_empty(self) -> bool:
+        return not self._data
+
+    @property
+    def output_stream(self) -> AsyncIterator[T]:
+        return self._data
+
+    def set_output(self, output_data: Union[T, AsyncIterator[T]]) -> None:
+        self._data = output_data
+
+    def new_output(self) -> TaskOutput[T]:
+        return SimpleStreamTaskOutput(None)
+
+    async def map(self, map_func) -> TaskOutput[T]:
+        is_async = asyncio.iscoroutinefunction(map_func)
+
+        async def new_iter() -> AsyncIterator[T]:
+            async for out in self._data:
+                if is_async:
+                    out = await map_func(out)
+                else:
+                    out = map_func(out)
+                yield out
+
+        return SimpleStreamTaskOutput(new_iter())
+
+    async def reduce(self, reduce_func) -> TaskOutput[T]:
+        out = await _reduce_stream(self._data, reduce_func)
+        return SimpleTaskOutput(out)
+
+    async def unstreamify(
+        self, transform_func: Callable[[AsyncIterator[T]], T]
+    ) -> TaskOutput[T]:
+        if asyncio.iscoroutinefunction(transform_func):
+            out = await transform_func(self._data)
+        else:
+            out = transform_func(self._data)
+        return SimpleTaskOutput(out)
+
+    async def transform_stream(
+        self, transform_func: Callable[[AsyncIterator[T]], AsyncIterator[T]]
+    ) -> TaskOutput[T]:
+        if asyncio.iscoroutinefunction(transform_func):
+            out = await transform_func(self._data)
+        else:
+            out = transform_func(self._data)
+        return SimpleStreamTaskOutput(out)
+
+
+def _is_async_iterator(obj):
+    return (
+        hasattr(obj, "__anext__")
+        and callable(getattr(obj, "__anext__", None))
+        and hasattr(obj, "__aiter__")
+        and callable(getattr(obj, "__aiter__", None))
+    )
+
+
+class BaseInputSource(InputSource, ABC):
+    def __init__(self) -> None:
+        super().__init__()
+        self._is_read = False
+
+    @abstractmethod
+    def _read_data(self, task_ctx: TaskContext) -> Any:
+        """Read data with the task context"""
+
+    async def read(self, task_ctx: TaskContext) -> TaskOutput:
+        data = self._read_data(task_ctx)
+        if _is_async_iterator(data):
+            if self._is_read:
+                raise ValueError(f"Input iterator {data} has been read!")
+            output = SimpleStreamTaskOutput(data)
+        else:
+            output = SimpleTaskOutput(data)
+        self._is_read = True
+        return output
+
+
+class SimpleInputSource(BaseInputSource):
+    def __init__(self, data: Any) -> None:
+        super().__init__()
+        self._data = data
+
+    def _read_data(self, task_ctx: TaskContext) -> Any:
+        return self._data
+
+
+class SimpleCallDataInputSource(BaseInputSource):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _read_data(self, task_ctx: TaskContext) -> Any:
+        call_data = task_ctx.call_data
+        data = call_data.get("data") if call_data else None
+        if not (call_data and data):
+            raise ValueError("No call data for current SimpleCallDataInputSource")
+        return data
+
+
+class DefaultTaskContext(TaskContext, Generic[T]):
+    def __init__(
+        self, task_id: str, task_state: TaskState, task_output: TaskOutput[T]
+    ) -> None:
+        super().__init__()
+        self._task_id = task_id
+        self._task_state = task_state
+        self._output = task_output
+        self._task_input = None
+        self._metadata = {}
+
+    @property
+    def task_id(self) -> str:
+        return self._task_id
+
+    @property
+    def task_input(self) -> InputContext:
+        return self._task_input
+
+    def set_task_input(self, input_ctx: "InputContext") -> None:
+        self._task_input = input_ctx
+
+    @property
+    def task_output(self) -> TaskOutput:
+        return self._output
+
+    def set_task_output(self, task_output: TaskOutput) -> None:
+        self._output = task_output
+
+    @property
+    def current_state(self) -> TaskState:
+        return self._task_state
+
+    def set_current_state(self, task_state: TaskState) -> None:
+        self._task_state = task_state
+
+    def new_ctx(self) -> TaskContext:
+        new_output = self._output.new_output()
+        return DefaultTaskContext(self._task_id, self._task_state, new_output)
+
+    @property
+    def metadata(self) -> Dict[str, Any]:
+        return self._metadata
+
+
+class DefaultInputContext(InputContext):
+    def __init__(self, outputs: List[TaskContext]) -> None:
+        super().__init__()
+        self._outputs = outputs
+
+    @property
+    def parent_outputs(self) -> List[TaskContext]:
+        return self._outputs
+
+    async def _apply_func(
+        self, func: Callable[[Any], Any], apply_type: str = "map"
+    ) -> Tuple[List[TaskContext], List[TaskOutput]]:
+        new_outputs: List[TaskContext] = []
+        map_tasks = []
+        for out in self._outputs:
+            new_outputs.append(out.new_ctx())
+            result = None
+            if apply_type == "map":
+                result = out.task_output.map(func)
+            elif apply_type == "reduce":
+                result = out.task_output.reduce(func)
+            elif apply_type == "check_condition":
+                result = out.task_output.check_condition(func)
+            else:
+                raise ValueError(f"Unsupported apply type {apply_type}")
+            map_tasks.append(result)
+        results = await asyncio.gather(*map_tasks)
+        return new_outputs, results
+
+    async def map(self, map_func: Callable[[Any], Any]) -> InputContext:
+        new_outputs, results = await self._apply_func(map_func)
+        for i, task_ctx in enumerate(new_outputs):
+            task_ctx: TaskContext = task_ctx
+            task_ctx.set_task_output(results[i])
+        return DefaultInputContext(new_outputs)
+
+    async def map_all(self, map_func: Callable[..., Any]) -> InputContext:
+        if not self._outputs:
+            return DefaultInputContext([])
+        # Some parents may be empty
+        not_empty_idx = 0
+        for i, p in enumerate(self._outputs):
+            if p.task_output.is_empty:
+                continue
+            not_empty_idx = i
+            break
+        # If all outputs are empty, not_empty_idx stays at the first output
+        is_stream = self._outputs[not_empty_idx].task_output.is_stream
+        if is_stream:
+            if not self.check_stream(skip_empty=True):
+                raise ValueError(
+                    "All task outputs must have the same (stream) format to apply map_all"
+                )
+        outputs = []
+        for out in self._outputs:
+            if out.task_output.is_stream:
+                outputs.append(out.task_output.output_stream)
+            else:
+                outputs.append(out.task_output.output)
+        if asyncio.iscoroutinefunction(map_func):
+            map_res = await map_func(*outputs)
+        else:
+            map_res = map_func(*outputs)
+        single_output: TaskContext = self._outputs[not_empty_idx].new_ctx()
+        single_output.task_output.set_output(map_res)
+        logger.debug(
+            f"Current map_all map_res: {map_res}, is stream: {single_output.task_output.is_stream}"
+        )
+        return DefaultInputContext([single_output])
+
+    async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
+        if not self.check_stream():
+            raise ValueError(
+                "All task outputs must be streams to apply the reduce function"
+            )
+        new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
+        for i, task_ctx in enumerate(new_outputs):
+            task_ctx: TaskContext = task_ctx
+            task_ctx.set_task_output(results[i])
+        return DefaultInputContext(new_outputs)
+
+    async def filter(self, filter_func: Callable[[Any], bool]) -> InputContext:
+        new_outputs, results = await self._apply_func(
+            filter_func, apply_type="check_condition"
+        )
+        result_outputs = []
+        for i, task_ctx in enumerate(new_outputs):
+            if results[i]:
+                result_outputs.append(task_ctx)
+        return DefaultInputContext(result_outputs)
+
+    async def predicate_map(
+        self, predicate_func: Callable[[Any], bool], failed_value: Any = None
+    ) -> "InputContext":
+        new_outputs, results = await self._apply_func(
+            predicate_func, apply_type="check_condition"
+        )
+        result_outputs = []
+        for i, task_ctx in enumerate(new_outputs):
+            task_ctx: TaskContext = task_ctx
+            if results[i]:
+                task_ctx.task_output.set_output(True)
+            else:
+                task_ctx.task_output.set_output(failed_value)
+            result_outputs.append(task_ctx)
+        return DefaultInputContext(result_outputs)
diff --git a/pilot/awel/tests/__init__.py b/pilot/awel/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/awel/tests/conftest.py b/pilot/awel/tests/conftest.py
new file mode 100644
index 000000000..2279cceba
--- /dev/null
+++ b/pilot/awel/tests/conftest.py
@@ -0,0 +1,102 @@
+import pytest
+import pytest_asyncio
+from typing import AsyncIterator, List
+from contextlib import contextmanager, asynccontextmanager
+from .. import (
+    WorkflowRunner,
+    InputOperator,
+    DAGContext,
+    TaskState,
+    DefaultWorkflowRunner,
+    SimpleInputSource,
+)
+from ..task.task_impl import _is_async_iterator
+
+
+@pytest.fixture
+def runner():
+    return DefaultWorkflowRunner()
+
+
+def _create_stream(num_nodes) -> List[AsyncIterator[int]]:
+    iters = []
+    for _ in range(num_nodes):
+
+        async def stream_iter():
+            for i in range(10):
+                yield i
+
+        stream_iter = stream_iter()
+        assert _is_async_iterator(stream_iter)
+        iters.append(stream_iter)
+    return iters
+
+
+def _create_stream_from(output_streams: List[List[int]]) -> List[AsyncIterator[int]]:
+    iters = []
+    for single_stream in output_streams:
+
+        async def stream_iter():
+            for i in single_stream:
+                yield i
+
+        stream_iter = stream_iter()
+        assert _is_async_iterator(stream_iter)
+        iters.append(stream_iter)
+    return iters
+
+
+@asynccontextmanager
+async def _create_input_node(**kwargs):
+    num_nodes = kwargs.get("num_nodes")
+    is_stream = kwargs.get("is_stream", False)
+    if is_stream:
+        outputs = kwargs.get("output_streams")
+        if outputs:
+            if num_nodes and num_nodes != len(outputs):
+                raise ValueError(
+                    f"num_nodes {num_nodes} != the length of output_streams {len(outputs)}"
+                )
+            outputs = _create_stream_from(outputs)
+        else:
+            num_nodes = num_nodes or 1
+            outputs = _create_stream(num_nodes)
+    else:
+        outputs = kwargs.get("outputs", ["Hello."])
+    nodes = []
+    for output in outputs:
+        input_source = SimpleInputSource(output)
+        input_node = InputOperator(input_source)
+        nodes.append(input_node)
+    yield nodes
+
+
+@pytest_asyncio.fixture
+async def input_node(request):
+    param = getattr(request, "param", {})
+    async with _create_input_node(**param) as input_nodes:
+        yield input_nodes[0]
+
+
+@pytest_asyncio.fixture
+async def stream_input_node(request):
+    param = getattr(request, "param", {})
+    param["is_stream"] = True
+    async with _create_input_node(**param) as input_nodes:
+        yield input_nodes[0]
+
+
+@pytest_asyncio.fixture
+async def input_nodes(request):
+    param = getattr(request, "param", {})
+    async with _create_input_node(**param) as input_nodes:
+        yield input_nodes
+
+
+@pytest_asyncio.fixture
+async def stream_input_nodes(request):
+    param = getattr(request, "param", {})
+    param["is_stream"] = True
+    async with _create_input_node(**param) as input_nodes:
+        yield input_nodes
import (
+    DAG,
+    WorkflowRunner,
+    DAGContext,
+    TaskState,
+    InputOperator,
+    MapOperator,
+    JoinOperator,
+    BranchOperator,
+    ReduceStreamOperator,
+    SimpleInputSource,
+)
+from .conftest import (
+    runner,
+    input_node,
+    input_nodes,
+    stream_input_node,
+    stream_input_nodes,
+    _is_async_iterator,
+)
+
+
+def _register_dag_to_fastapi_app(dag):
+    # TODO
+    pass
+
+
+@pytest.mark.asyncio
+async def test_http_operator(runner: WorkflowRunner, stream_input_node: InputOperator):
+    with DAG("test_map") as dag:
+        pass
+        # http_req_task = HttpRequestOperator(endpoint="/api/completions")
+        # db_task = DBQueryOperator(table_name="user_info")
+        # prompt_task = PromptTemplateOperator(
+        #     system_prompt="You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
+        # )
+        # llm_task = ChatGPTLLMOperator(model="chatgpt-3.5")
+        # output_parser_task = CommonOutputParserOperator()
+        # http_res_task = HttpResponseOperator()
+        # (
+        #     http_req_task
+        #     >> db_task
+        #     >> prompt_task
+        #     >> llm_task
+        #     >> output_parser_task
+        #     >> http_res_task
+        # )
+
+    _register_dag_to_fastapi_app(dag)
diff --git a/pilot/awel/tests/test_run_dag.py b/pilot/awel/tests/test_run_dag.py
new file mode 100644
index 000000000..c0ea8e7ad
--- /dev/null
+++ b/pilot/awel/tests/test_run_dag.py
@@ -0,0 +1,141 @@
+import pytest
+from typing import List
+from .. import (
+    DAG,
+    WorkflowRunner,
+    DAGContext,
+    TaskState,
+    InputOperator,
+    MapOperator,
+    JoinOperator,
+    BranchOperator,
+    ReduceStreamOperator,
+    SimpleInputSource,
+)
+from .conftest import (
+    runner,
+    input_node,
+    input_nodes,
+    stream_input_node,
+    stream_input_nodes,
+    _is_async_iterator,
+)
+
+
+@pytest.mark.asyncio
+async def test_input_node(runner: WorkflowRunner):
+    input_node = InputOperator(SimpleInputSource("hello"))
+    res: DAGContext[str] = await runner.execute_workflow(input_node)
+    assert res.current_task_context.current_state == TaskState.SUCCESS
+    assert res.current_task_context.task_output.output == "hello"
+
+    async def new_stream_iter(n: int):
+        for i in range(n):
+            yield i
+
+    num_iter = 10
+    stream_input_node = InputOperator(SimpleInputSource(new_stream_iter(num_iter)))
+    res: DAGContext[str] = await runner.execute_workflow(stream_input_node)
+    assert res.current_task_context.current_state == TaskState.SUCCESS
+    output_stream = res.current_task_context.task_output.output_stream
+    assert output_stream
+    assert _is_async_iterator(output_stream)
+    i = 0
+    async for x in output_stream:
+        assert x == i
+        i += 1
+
+
+@pytest.mark.asyncio
+async def test_map_node(runner: WorkflowRunner, stream_input_node: InputOperator):
+    with DAG("test_map") as dag:
+        map_node = MapOperator(lambda x: x * 2)
+        stream_input_node >> map_node
+        res: DAGContext[int] = await runner.execute_workflow(map_node)
+        output_stream = res.current_task_context.task_output.output_stream
+        assert output_stream
+        i = 0
+        async for x in output_stream:
+            assert x == i * 2
+            i += 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "stream_input_node, expect_sum",
+    [
+        ({"output_streams": [[0, 1, 2, 3]]}, 6),
+        ({"output_streams": [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]}, 55),
+    ],
+    indirect=["stream_input_node"],
+)
+async def test_reduce_node(
+    runner: WorkflowRunner, stream_input_node: InputOperator, expect_sum: int
+):
+    with DAG("test_reduce_node") as dag:
+        reduce_node = ReduceStreamOperator(lambda x, y: x + y)
+        stream_input_node >> reduce_node
+        res: DAGContext[int] = await runner.execute_workflow(reduce_node)
+        assert res.current_task_context.current_state == TaskState.SUCCESS
+        assert not res.current_task_context.task_output.is_stream
+        assert res.current_task_context.task_output.output == expect_sum
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "input_nodes",
+    [
+        ({"outputs": [0, 1, 2]}),
+    ],
+    indirect=["input_nodes"],
+)
+async def test_join_node(runner: WorkflowRunner, input_nodes: List[InputOperator]):
+    def join_func(p1, p2, p3) -> int:
+        return p1 + p2 + p3
+
+    with DAG("test_join_node") as dag:
+        join_node = JoinOperator(join_func)
+        for input_node in input_nodes:
+            input_node >> join_node
+        res: DAGContext[int] = await runner.execute_workflow(join_node)
+        assert res.current_task_context.current_state == TaskState.SUCCESS
+        assert not res.current_task_context.task_output.is_stream
+        assert res.current_task_context.task_output.output == 3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "input_node, is_odd",
+    [
+        ({"outputs": [0]}, False),
+        ({"outputs": [1]}, True),
+    ],
+    indirect=["input_node"],
+)
+async def test_branch_node(
+    runner: WorkflowRunner, input_node: InputOperator, is_odd: bool
+):
+    def join_func(o1, o2) -> int:
+        print(f"join func result, o1: {o1}, o2: {o2}")
+        return o1 or o2
+
+    with DAG("test_branch_node") as dag:
+        odd_node = MapOperator(
+            lambda x: 999, task_id="odd_node", task_name="odd_node_name"
+        )
+        even_node = MapOperator(
+            lambda x: 888, task_id="even_node", task_name="even_node_name"
+        )
+        join_node = JoinOperator(join_func)
+        branch_node = BranchOperator(
+            {lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
+        )
+        branch_node >> odd_node >> join_node
+        branch_node >> even_node >> join_node
+
+        input_node >> branch_node
+
+        res: DAGContext[int] = await runner.execute_workflow(join_node)
+        assert res.current_task_context.current_state == TaskState.SUCCESS
+        expect_res = 999 if is_odd else 888
+        assert res.current_task_context.task_output.output == expect_res
diff --git a/pilot/cache/__init__.py b/pilot/cache/__init__.py
new file mode 100644
index 000000000..65f768a7e
--- /dev/null
+++ b/pilot/cache/__init__.py
@@ -0,0 +1,10 @@
+from pilot.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
+from pilot.cache.manager import CacheManager, initialize_cache
+
+__all__ = [
+    "LLMCacheKey",
+    "LLMCacheValue",
+    "LLMCacheClient",
+    "CacheManager",
+    "initialize_cache",
+]
diff --git a/pilot/cache/base.py b/pilot/cache/base.py
new file mode 100644
index 000000000..feb135288
--- /dev/null
+++ b/pilot/cache/base.py
@@ -0,0 +1,161 @@
+from abc import ABC, abstractmethod
+from typing import Any, TypeVar, Generic, Optional, Type, Dict
+from dataclasses import dataclass
+from enum import Enum
+
+T = TypeVar("T", bound="Serializable")
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class Serializable(ABC):
+    @abstractmethod
+    def serialize(self) -> bytes:
+        """Convert the object into bytes for storage or transmission.
+
+        Returns:
+            bytes: The byte array after serialization
+        """
+
+    @abstractmethod
+    def to_dict(self) -> Dict:
+        """Convert the object's state to a dictionary."""
+
+    # @staticmethod
+    # @abstractclassmethod
+    # def from_dict(cls: Type["Serializable"], obj_dict: Dict) -> "Serializable":
+    #     """Deserialize a dictionary to a Serializable object.
+    #     """
+
+
+class RetrievalPolicy(str, Enum):
+    EXACT_MATCH = "exact_match"
+    SIMILARITY_MATCH = "similarity_match"
+
+
+class CachePolicy(str, Enum):
+    LRU = "lru"
+    FIFO = "fifo"
+
+
+@dataclass
+class CacheConfig:
+    retrieval_policy: Optional[RetrievalPolicy] = RetrievalPolicy.EXACT_MATCH
+    cache_policy: Optional[CachePolicy] = CachePolicy.LRU
+
+
+class CacheKey(Serializable, ABC, Generic[K]):
+    """The key of the cache. Must be hashable and comparable.
+
+    Supported cache keys:
+    - The LLM cache key: includes the user prompt and the parameters passed to the LLM.
+    - The embedding model cache key: includes the texts to embed and the parameters passed to the embedding model.
+    """
+
+    @abstractmethod
+    def __hash__(self) -> int:
+        """Return the hash value of the key."""
+
+    @abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        """Check equality with another key."""
+
+    @abstractmethod
+    def get_hash_bytes(self) -> bytes:
+        """Return the byte array of hash value."""
+
+    @abstractmethod
+    def get_value(self) -> K:
+        """Get the underlying value of the cache key.
+
+        Returns:
+            K: The real object of current cache key
+        """
+
+
+class CacheValue(Serializable, ABC, Generic[V]):
+    """Cache value abstract class."""
+
+    @abstractmethod
+    def get_value(self) -> V:
+        """Get the underlying real value."""
+
+
+class Serializer(ABC):
+    """The serializer abstract class for serializing cache keys and values."""
+
+    @abstractmethod
+    def serialize(self, obj: Serializable) -> bytes:
+        """Serialize a cache object.
+
+        Args:
+            obj (Serializable): The object to serialize
+        """
+
+    @abstractmethod
+    def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
+        """Deserialize data back into a cache object of the specified type.
+
+        Args:
+            data (bytes): The byte array to deserialize
+            cls (Type[Serializable]): The type of current object
+
+        Returns:
+            Serializable: The serializable object
+        """
+
+
+class CacheClient(ABC, Generic[K, V]):
+    """The cache client interface."""
+
+    @abstractmethod
+    async def get(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> Optional[CacheValue[V]]:
+        """Retrieve a value from the cache using the provided key.
+
+        Args:
+            key (CacheKey[K]): The key to get cache
+            cache_config (Optional[CacheConfig]): Cache config
+
+        Returns:
+            Optional[CacheValue[V]]: The value retrieved according to the key, or None if the key does not exist.
+        """
+
+    @abstractmethod
+    async def set(
+        self,
+        key: CacheKey[K],
+        value: CacheValue[V],
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        """Set a value in the cache for the provided key.
+
+        Args:
+            key (CacheKey[K]): The key to set to cache
+            value (CacheValue[V]): The value to set to cache
+            cache_config (Optional[CacheConfig]): Cache config
+        """
+
+    @abstractmethod
+    async def exists(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> bool:
+        """Check if a key exists in the cache.
+
+        Args:
+            key (CacheKey[K]): The key to check in the cache
+            cache_config (Optional[CacheConfig]): Cache config
+
+        Returns:
+            bool: True if the key exists in the cache, otherwise False
+        """
+
+    @abstractmethod
+    def new_key(self, **kwargs) -> CacheKey[K]:
+        """Create a cache key with params"""
+
+    @abstractmethod
+    def new_value(self, **kwargs) -> CacheValue[V]:
+        """Create a cache value with params"""
diff --git a/pilot/cache/embedding_cache.py b/pilot/cache/embedding_cache.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/llm_cache.py b/pilot/cache/llm_cache.py
new file mode 100644
index 000000000..ad559df03
--- /dev/null
+++ b/pilot/cache/llm_cache.py
@@ -0,0 +1,148 @@
+from typing import Optional, Dict, Any, Union, List
+from dataclasses import dataclass, asdict
+import json
+import hashlib
+
+from pilot.cache.base import CacheKey, CacheValue, Serializer, CacheClient, CacheConfig
+from pilot.cache.manager import CacheManager
+from pilot.cache.storage.base import CacheStorage
+from pilot.model.base import ModelType, ModelOutput
+
+
+@dataclass
+class LLMCacheKeyData:
+    prompt: str
+    model_name: str
+    temperature: Optional[float] = 0.7
+    max_new_tokens: Optional[int] = None
+    top_p: Optional[float] = 1.0
+    model_type: Optional[str] = ModelType.HF
+
+
+CacheOutputType = Union[ModelOutput, List[ModelOutput]]
+
+
+@dataclass
+class LLMCacheValueData:
+    output: CacheOutputType
+    user: Optional[str] = None
+    _is_list: Optional[bool] = False
+
+    @staticmethod
+    def from_dict(**kwargs) -> "LLMCacheValueData":
+        output = kwargs.get("output")
+        if not output:
+            raise ValueError("Can't create LLMCacheValueData object, output is None")
+        if isinstance(output, dict):
+            output = ModelOutput(**output)
+        elif isinstance(output, list):
+            kwargs["_is_list"] = True
+            output_list = []
+            for out in output:
+                if isinstance(out, dict):
+                    out = ModelOutput(**out)
+                output_list.append(out)
+            output = output_list
+        kwargs["output"] = output
+        return LLMCacheValueData(**kwargs)
+
+    def to_dict(self) -> Dict:
+        output = self.output
+        is_list = False
+        if isinstance(output, list):
+            output_list = []
+            is_list = True
+            for out in output:
+                output_list.append(out.to_dict())
+            output = output_list
+        else:
+            output = output.to_dict()
+        return {"output": output, "_is_list": is_list, "user": self.user}
+
+    @property
+    def is_list(self) -> bool:
+        return self._is_list
+
+    def __str__(self) -> str:
+        if not isinstance(self.output, list):
+            return f"user: {self.user}, output: {self.output}"
+        else:
+            return f"user: {self.user}, output (last two items): {self.output[-2:]}"
+
+
+class LLMCacheKey(CacheKey[LLMCacheKeyData]):
+    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+        super().__init__()
+        self._serializer = serializer
+        self.config = LLMCacheKeyData(**kwargs)
+
+    def __hash__(self) -> int:
+        serialize_bytes = self.serialize()
+        return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, LLMCacheKey):
+            return False
+        return self.config == other.config
+
+    def get_hash_bytes(self) -> bytes:
+        serialize_bytes = self.serialize()
+        return hashlib.sha256(serialize_bytes).digest()
+
+    def to_dict(self) -> Dict:
+        return asdict(self.config)
+
+    def serialize(self) -> bytes:
+        return self._serializer.serialize(self)
+
+    def get_value(self) -> LLMCacheKeyData:
+        return self.config
+
+
+class LLMCacheValue(CacheValue[LLMCacheValueData]):
+    def __init__(self, serializer: Serializer = None, **kwargs) -> None:
+        super().__init__()
+        self._serializer = serializer
+        self.value = LLMCacheValueData.from_dict(**kwargs)
+
+    def to_dict(self) -> Dict:
+        return self.value.to_dict()
+
+    def serialize(self) -> bytes:
+        return self._serializer.serialize(self)
+
+    def get_value(self) -> LLMCacheValueData:
+        return self.value
+
+    def __str__(self) -> str:
+        return f"value: {str(self.value)}"
+
+
+class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
+    def __init__(self, cache_manager: CacheManager) -> None:
+        super().__init__()
+        self._cache_manager: CacheManager = cache_manager
+
+    async def get(
+        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+    ) -> Optional[LLMCacheValue]:
+        return await self._cache_manager.get(key, LLMCacheValue, cache_config)
+
+    async def set(
+        self,
+        key: LLMCacheKey,
+        value: LLMCacheValue,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        return await self._cache_manager.set(key, value, cache_config)
+
+    async def exists(
+        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+    ) -> bool:
+        return await self.get(key, cache_config) is not None
+
+    def new_key(self, **kwargs) -> LLMCacheKey:
+        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)
+
+    def new_value(self, **kwargs) -> LLMCacheValue:
+        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
diff --git a/pilot/cache/manager.py b/pilot/cache/manager.py
new file mode 100644
index 000000000..0e76df0b3
--- /dev/null
+++ b/pilot/cache/manager.py
@@ -0,0 +1,127 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Type
+import logging
+from concurrent.futures import Executor
+from pilot.cache.storage.base import CacheStorage, StorageItem
+from pilot.cache.base import (
+    K,
+    V,
+    CacheKey,
+    CacheValue,
+    CacheConfig,
+    Serializer,
+    Serializable,
+)
+from pilot.component import BaseComponent, ComponentType, SystemApp
+from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
+
+logger = logging.getLogger(__name__)
+
+
+class CacheManager(BaseComponent, ABC):
+    name = ComponentType.MODEL_CACHE_MANAGER
+
+    def __init__(self, system_app: SystemApp | None = None):
+        super().__init__(system_app)
+
+    def init_app(self, system_app: SystemApp):
+        self.system_app = system_app
+
+    @abstractmethod
+    async def set(
+        self,
+        key: CacheKey[K],
+        value: CacheValue[V],
+        cache_config: Optional[CacheConfig] = None,
+    ):
+        """Set cache"""
+
+    @abstractmethod
+    async def get(
+        self,
+        key: CacheKey[K],
+        cls: Type[Serializable],
+        cache_config: Optional[CacheConfig] = None,
+    ) -> CacheValue[V]:
+        """Get cache with key"""
+
+    @property
+    @abstractmethod
+    def serializer(self) -> Serializer:
+        """Get cache serializer"""
+
+
+class LocalCacheManager(CacheManager):
+    def __init__(
+        self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
+    ) -> None:
+        super().__init__(system_app)
+        self._serializer = serializer
+        self._storage = storage
+
+    @property
+    def executor(self) -> Executor:
+        """Return executor to submit task"""
+        self._executor = self.system_app.get_component(
+            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+        ).create()
+        return self._executor
+
+    async def set(
+        self,
+        key: CacheKey[K],
+        value: CacheValue[V],
+        cache_config: Optional[CacheConfig] = None,
+    ):
+        if self._storage.support_async():
+            await self._storage.aset(key, value, cache_config)
+        else:
+            await blocking_func_to_async(
+                self.executor, self._storage.set, key, value, cache_config
+            )
+
+    async def get(
+        self,
+        key: CacheKey[K],
+        cls: Type[Serializable],
+        cache_config: Optional[CacheConfig] = None,
+    ) -> CacheValue[V]:
+        if self._storage.support_async():
+            item_bytes = await self._storage.aget(key, cache_config)
+        else:
+            item_bytes = await blocking_func_to_async(
+                self.executor, self._storage.get, key, cache_config
+            )
+        if not item_bytes:
+            return None
+        return self._serializer.deserialize(item_bytes.value_data, cls)
+
+    @property
+    def serializer(self) -> Serializer:
+        return self._serializer
+
+
+def initialize_cache(
+    system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
+):
+    from pilot.cache.protocal.json_protocal import JsonSerializer
+    from pilot.cache.storage.base import MemoryCacheStorage
+
+    cache_storage = None
+    if storage_type == "disk":
+        try:
+            from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
+
+            cache_storage = DiskCacheStorage(
+                persist_dir, mem_table_buffer_mb=max_memory_mb
+            )
+        except ImportError as e:
+            logger.warning(
+                f"Can't import DiskCacheStorage; falling back to MemoryCacheStorage. Import error: {str(e)}"
+            )
+            cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
+    else:
+        cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
+    system_app.register(
+        LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
+    )
diff --git a/pilot/cache/protocal/__init__.py b/pilot/cache/protocal/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/protocal/json_protocal.py b/pilot/cache/protocal/json_protocal.py
new file mode 100644
index 000000000..6f73fef3f
--- /dev/null
+++ b/pilot/cache/protocal/json_protocal.py
@@ -0,0 +1,44 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Type
+import json
+
+from pilot.cache.base import Serializable, Serializer
+
+JSON_ENCODING = "utf-8"
+
+
+class JsonSerializable(Serializable, ABC):
+    @abstractmethod
+    def to_dict(self) -> Dict:
+        """Return the dict of current serializable object"""
+
+    def serialize(self) -> bytes:
+        """Convert the object into bytes for storage or transmission."""
+        return json.dumps(self.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
+
+
+class JsonSerializer(Serializer):
+    """A JSON implementation of the serializer for cache keys and values."""
+
+    def serialize(self, obj: Serializable) -> bytes:
+        """Serialize a cache object.
+
+        Args:
+            obj (Serializable): The object to serialize
+        """
+        return json.dumps(obj.to_dict(), ensure_ascii=False).encode(JSON_ENCODING)
+
+    def deserialize(self, data: bytes, cls: Type[Serializable]) -> Serializable:
+        """Deserialize data back into a cache object of the specified type.
+ + Args: + data (bytes): The byte array to deserialize + cls (Type[Serializable]): The type of current object + + Returns: + Serializable: The serializable object + """ + # Convert bytes back to JSON and then to the specified class + json_data = json.loads(data.decode(JSON_ENCODING)) + # Assume that the cls has an __init__ that accepts a dictionary + return cls(**json_data) diff --git a/pilot/cache/storage/__init__.py b/pilot/cache/storage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/cache/storage/base.py b/pilot/cache/storage/base.py new file mode 100644 index 000000000..ea07bfacf --- /dev/null +++ b/pilot/cache/storage/base.py @@ -0,0 +1,252 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass +from collections import OrderedDict +import msgpack +import logging + +from pilot.cache.base import ( + K, + V, + CacheKey, + CacheValue, + CacheClient, + CacheConfig, + RetrievalPolicy, + CachePolicy, +) +from pilot.utils.memory_utils import _get_object_bytes + +logger = logging.getLogger(__name__) + + +@dataclass +class StorageItem: + """ + A class representing a storage item. + + This class encapsulates data related to a storage item, such as its length, + the hash of the key, and the data for both the key and value. + + Parameters: + length (int): The bytes length of the storage item. + key_hash (bytes): The hash value of the storage item's key. + key_data (bytes): The data of the storage item's key, represented in bytes. + value_data (bytes): The data of the storage item's value, also in bytes. + """ + + length: int # The bytes length of the storage item + key_hash: bytes # The hash value of the storage item's key + key_data: bytes # The data of the storage item's key + value_data: bytes # The data of the storage item's value + + @staticmethod + def build_from( + key_hash: bytes, key_data: bytes, value_data: bytes + ) -> "StorageItem": + length = ( + 32 + + _get_object_bytes(key_hash) + + _get_object_bytes(key_data) + + _get_object_bytes(value_data) + ) + return StorageItem( + length=length, key_hash=key_hash, key_data=key_data, value_data=value_data + ) + + @staticmethod + def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem": + key_hash = key.get_hash_bytes() + key_data = key.serialize() + value_data = value.serialize() + return StorageItem.build_from(key_hash, key_data, value_data) + + def serialize(self) -> bytes: + """Serialize the StorageItem into a byte stream using MessagePack. + + This method packs the object data into a dictionary, marking the + key_data and value_data fields as raw binary data to avoid re-serialization. + + Returns: + bytes: The serialized bytes. + """ + obj = { + "length": self.length, + "key_hash": msgpack.ExtType(1, self.key_hash), + "key_data": msgpack.ExtType(2, self.key_data), + "value_data": msgpack.ExtType(3, self.value_data), + } + return msgpack.packb(obj) + + @staticmethod + def deserialize(data: bytes) -> "StorageItem": + """Deserialize bytes back into a StorageItem using MessagePack. + + This extracts the fields from the MessagePack dict back into + a StorageItem object. + + Args: + data (bytes): Serialized bytes + + Returns: + StorageItem: Deserialized StorageItem object. 
+        """
+        obj = msgpack.unpackb(data)
+        key_hash = obj["key_hash"].data
+        key_data = obj["key_data"].data
+        value_data = obj["value_data"].data
+
+        return StorageItem(
+            length=obj["length"],
+            key_hash=key_hash,
+            key_data=key_data,
+            value_data=value_data,
+        )
+
+
+class CacheStorage(ABC):
+    @abstractmethod
+    def check_config(
+        self,
+        cache_config: Optional[CacheConfig] = None,
+        raise_error: Optional[bool] = True,
+    ) -> bool:
+        """Check whether the CacheConfig is legal.
+
+        Args:
+            cache_config (Optional[CacheConfig]): Cache config.
+            raise_error (Optional[bool]): Whether to raise an error if the config is illegal.
+
+        Raises:
+            ValueError: If raise_error is True and the config is illegal.
+        """
+
+    def support_async(self) -> bool:
+        return False
+
+    @abstractmethod
+    def get(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> Optional[StorageItem]:
+        """Retrieve a storage item from the cache using the provided key.
+
+        Args:
+            key (CacheKey[K]): The key to get cache
+            cache_config (Optional[CacheConfig]): Cache config
+
+        Returns:
+            Optional[StorageItem]: The storage item retrieved according to the key, or None if the key does not exist.
+        """
+
+    async def aget(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> Optional[StorageItem]:
+        """Retrieve a storage item from the cache using the provided key asynchronously.
+
+        Args:
+            key (CacheKey[K]): The key to get cache
+            cache_config (Optional[CacheConfig]): Cache config
+
+        Returns:
+            Optional[StorageItem]: The storage item retrieved according to the key, or None if the key does not exist.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def set(
+        self,
+        key: CacheKey[K],
+        value: CacheValue[V],
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        """Set a value in the cache for the provided key.
+
+        Args:
+            key (CacheKey[K]): The key to set to cache
+            value (CacheValue[V]): The value to set to cache
+            cache_config (Optional[CacheConfig]): Cache config
+        """
+
+    async def aset(
+        self,
+        key: CacheKey[K],
+        value: CacheValue[V],
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        """Set a value in the cache for the provided key asynchronously.
+
+        Args:
+            key (CacheKey[K]): The key to set to cache
+            value (CacheValue[V]): The value to set to cache
+            cache_config (Optional[CacheConfig]): Cache config
+        """
+        raise NotImplementedError
+
+
+class MemoryCacheStorage(CacheStorage):
+    def __init__(self, max_memory_mb: int = 256):
+        self.cache = OrderedDict()
+        self.max_memory = max_memory_mb * 1024 * 1024
+        self.current_memory_usage = 0
+
+    def check_config(
+        self,
+        cache_config: Optional[CacheConfig] = None,
+        raise_error: Optional[bool] = True,
+    ) -> bool:
+        if (
+            cache_config
+            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
+        ):
+            if raise_error:
+                raise ValueError(
+                    "MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
+                )
+            return False
+        return True
+
+    def get(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> Optional[StorageItem]:
+        self.check_config(cache_config, raise_error=True)
+        # Exact match retrieval
+        key_hash = hash(key)
+        item: StorageItem = self.cache.get(key_hash)
+        logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
+
+        if not item:
+            return None
+        # Move the item to the end of the OrderedDict to signify recent use.
+        self.cache.move_to_end(key_hash)
+        return item
+
+    def set(
+        self,
+        key: CacheKey[K],
+        value: CacheValue[V],
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        key_hash = hash(key)
+        item = StorageItem.build_from_kv(key, value)
+        # Calculate memory size of the new entry
+        new_entry_size = _get_object_bytes(item)
+        # Evict entries if necessary
+        while self.cache and self.current_memory_usage + new_entry_size > self.max_memory:
+            self._apply_cache_policy(cache_config)
+
+        # Store the item in the cache.
+        self.cache[key_hash] = item
+        self.current_memory_usage += new_entry_size
+        logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")
+
+    def exists(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> bool:
+        return self.get(key, cache_config) is not None
+
+    def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
+        # Evict from the front of the OrderedDict: get() moves accessed items to
+        # the end, so the front holds the least recently used (LRU) and oldest
+        # inserted (FIFO) entry. Release the memory accounted to the evicted item.
+        _, evicted = self.cache.popitem(last=False)
+        self.current_memory_usage -= _get_object_bytes(evicted)
diff --git a/pilot/cache/storage/disk/__init__.py b/pilot/cache/storage/disk/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/cache/storage/disk/disk_storage.py b/pilot/cache/storage/disk/disk_storage.py
new file mode 100644
index 000000000..04fb19c6e
--- /dev/null
+++ b/pilot/cache/storage/disk/disk_storage.py
@@ -0,0 +1,92 @@
+from typing import Optional
+import logging
+from pilot.cache.base import (
+    K,
+    V,
+    CacheKey,
+    CacheValue,
+    CacheConfig,
+    RetrievalPolicy,
+    CachePolicy,
+)
+from pilot.cache.storage.base import StorageItem, CacheStorage
+from rocksdict import Rdict, Options, SliceTransform, PlainTableFactoryOptions
+
+
+logger = logging.getLogger(__name__)
+
+
+def db_options(
+    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
+):
+    opt = Options()
+    # create the database if it does not exist
+    opt.create_if_missing(True)
+    # configure more background jobs (default is 2)
+    opt.set_max_background_jobs(background_threads)
+    # configure the mem-table to a large value
+    opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
+    # opt.set_write_buffer_size(1024)
+    # opt.set_level_zero_file_num_compaction_trigger(4)
+    # configure l0 and l1 size, let them have the same size (1 GB)
+    # opt.set_max_bytes_for_level_base(0x40000000)
+    # 256 MB file size
+    # opt.set_target_file_size_base(0x10000000)
+    # use a smaller compaction multiplier
+    # opt.set_max_bytes_for_level_multiplier(4.0)
+    # use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
+    # opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
+    # set to plain-table
+    # opt.set_plain_table_factory(PlainTableFactoryOptions())
+    return opt
+
+
+class DiskCacheStorage(CacheStorage):
+    def __init__(
+        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
+    ) -> None:
+        super().__init__()
+        self.db: Rdict = Rdict(
+            persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
+        )
+
+    def check_config(
+        self,
+        cache_config: Optional[CacheConfig] = None,
+        raise_error: Optional[bool] = True,
+    ) -> bool:
+        if (
+            cache_config
+            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
+        ):
+            if raise_error:
+                raise ValueError(
+                    "DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
+                )
+            return False
+        return True
+
+    def get(
+        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
+    ) -> Optional[StorageItem]:
+        
self.check_config(cache_config, raise_error=True) + + # Exact match retrieval + key_hash = key.get_hash_bytes() + item_bytes = self.db.get(key_hash) + if not item_bytes: + return None + item = StorageItem.deserialize(item_bytes) + logger.debug(f"Read file cache, key: {key}, storage item: {item}") + return item + + def set( + self, + key: CacheKey[K], + value: CacheValue[V], + cache_config: Optional[CacheConfig] = None, + ) -> None: + item = StorageItem.build_from_kv(key, value) + key_hash = item.key_hash + self.db[key_hash] = item.serialize() + logger.debug(f"Save file cache, key: {key}, value: {value}") diff --git a/pilot/cache/storage/tests/__init__.py b/pilot/cache/storage/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/cache/storage/tests/test_storage.py b/pilot/cache/storage/tests/test_storage.py new file mode 100644 index 000000000..489873d08 --- /dev/null +++ b/pilot/cache/storage/tests/test_storage.py @@ -0,0 +1,53 @@ +import pytest +from ..base import StorageItem +from pilot.utils.memory_utils import _get_object_bytes + + +def test_build_from(): + key_hash = b"key_hash" + key_data = b"key_data" + value_data = b"value_data" + item = StorageItem.build_from(key_hash, key_data, value_data) + + assert item.key_hash == key_hash + assert item.key_data == key_data + assert item.value_data == value_data + assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes( + key_data + ) + _get_object_bytes(value_data) + + +def test_build_from_kv(): + class MockCacheKey: + def get_hash_bytes(self): + return b"key_hash" + + def serialize(self): + return b"key_data" + + class MockCacheValue: + def serialize(self): + return b"value_data" + + key = MockCacheKey() + value = MockCacheValue() + item = StorageItem.build_from_kv(key, value) + + assert item.key_hash == key.get_hash_bytes() + assert item.key_data == key.serialize() + assert item.value_data == value.serialize() + + +def test_serialize_deserialize(): + key_hash = b"key_hash" + key_data = b"key_data" + value_data = b"value_data" + item = StorageItem.build_from(key_hash, key_data, value_data) + + serialized = item.serialize() + deserialized = StorageItem.deserialize(serialized) + + assert deserialized.key_hash == item.key_hash + assert deserialized.key_data == item.key_data + assert deserialized.value_data == item.value_data + assert deserialized.length == item.length diff --git a/pilot/component.py b/pilot/component.py index 16013ee17..d79a8d395 100644 --- a/pilot/component.py +++ b/pilot/component.py @@ -48,6 +48,7 @@ class ComponentType(str, Enum): MODEL_CONTROLLER = "dbgpt_model_controller" MODEL_REGISTRY = "dbgpt_model_registry" MODEL_API_SERVER = "dbgpt_model_api_server" + MODEL_CACHE_MANAGER = "dbgpt_model_cache_manager" AGENT_HUB = "dbgpt_agent_hub" EXECUTOR_DEFAULT = "dbgpt_thread_pool_default" TRACER = "dbgpt_tracer" diff --git a/pilot/configs/config.py b/pilot/configs/config.py index b263b46c4..f93cd7b83 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -253,6 +253,19 @@ def __init__(self) -> None: ### Temporary configuration self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true" + self.MODEL_CACHE_ENABLE: bool = ( + os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true" + ) + self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv( + "MODEL_CACHE_STORAGE_TYPE", "disk" + ) + self.MODEL_CACHE_MAX_MEMORY_MB: int = int( + os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256) + ) + self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv( + 
"MODEL_CACHE_STORAGE_DISK_DIR" + ) + def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" self.debug_mode = value diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 0e1fb3d40..cedfa8554 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -14,6 +14,7 @@ # nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") FONT_DIR = os.path.join(PILOT_PATH, "fonts") +MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") current_directory = os.getcwd() diff --git a/pilot/model/cluster/base.py b/pilot/model/cluster/base.py index 9d22161b1..36c4779b8 100644 --- a/pilot/model/cluster/base.py +++ b/pilot/model/cluster/base.py @@ -17,6 +17,8 @@ class PromptRequest(BaseModel): temperature: float = None max_new_tokens: int = None stop: str = None + stop_token_ids: List[int] = [] + context_len: int = None echo: bool = True span_id: str = None diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py index 44a476f20..d6663fc9f 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -1,6 +1,6 @@ import os import logging -from typing import Dict, Iterator, List +from typing import Dict, Iterator, List, Optional from pilot.configs.model_config import get_device from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper @@ -60,7 +60,7 @@ def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: self.ml: ModelLoader = ModelLoader( model_path=self.model_path, model_name=self.model_name ) - # TODO read context len from model config + # Default model context len self.context_len = 2048 def model_param_class(self) -> ModelParameters: @@ -111,6 +111,12 @@ def start( self.model, self.tokenizer = self.ml.loader_with_params( model_params, self.llm_adapter ) + model_max_length = _parse_model_max_length(self.model, self.tokenizer) + if model_max_length: + logger.info( + f"Parse model max length {model_max_length} from model {self.model_name}." + ) + self.context_len = model_max_length def stop(self) -> None: if not self.model: @@ -138,9 +144,9 @@ def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: ) previous_response = "" - + context_len = params.get("context_len") or self.context_len for output in generate_stream_func( - self.model, self.tokenizer, params, get_device(), self.context_len + self.model, self.tokenizer, params, get_device(), context_len ): model_output, incremental_output, output_str = self._handle_output( output, previous_response, model_context @@ -183,9 +189,10 @@ async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]: ) previous_response = "" + context_len = params.get("context_len") or self.context_len async for output in generate_stream_func( - self.model, self.tokenizer, params, get_device(), self.context_len + self.model, self.tokenizer, params, get_device(), context_len ): model_output, incremental_output, output_str = self._handle_output( output, previous_response, model_context @@ -279,11 +286,27 @@ def _handle_exception(self, e): # Check if the exception is a torch.cuda.CudaError and if torch was imported. 
if _torch_imported and isinstance(e, torch.cuda.CudaError): model_output = ModelOutput( - text="**GPU OutOfMemory, Please Refresh.**", error_code=0 + text="**GPU OutOfMemory, Please Refresh.**", error_code=1 ) else: model_output = ModelOutput( text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=0, + error_code=1, ) return model_output + + +def _parse_model_max_length(model, tokenizer) -> Optional[int]: + if not (tokenizer or model): + return None + try: + if tokenizer and hasattr(tokenizer, "model_max_length"): + return tokenizer.model_max_length + if model and hasattr(model, "config"): + model_config = model.config + if hasattr(model_config, "max_sequence_length"): + return model_config.max_sequence_length + if hasattr(model_config, "max_position_embeddings"): + return model_config.max_position_embeddings + except Exception: + return None diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index d67519f59..2dd402920 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -119,7 +119,10 @@ async def start(self): _async_heartbeat_sender(self.run_data, 20, self.send_heartbeat_func) ) for listener in self.start_listeners: - listener(self) + if asyncio.iscoroutinefunction(listener): + await listener(self) + else: + listener(self) async def stop(self, ignore_exception: bool = False): if not self.run_data.stop_event.is_set(): @@ -325,7 +328,7 @@ async def generate_stream( except Exception as e: yield ModelOutput( text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=0, + error_code=1, ) return async with worker_run_data.semaphore: @@ -355,7 +358,7 @@ async def generate(self, params: Dict) -> ModelOutput: except Exception as e: return ModelOutput( text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=0, + error_code=1, ) async with worker_run_data.semaphore: if worker_run_data.worker.support_async(): @@ -996,6 +999,7 @@ def run_worker_manager( port: int = None, embedding_model_name: str = None, embedding_model_path: str = None, + start_listener: Callable[["WorkerManager"], None] = None, ): global worker_manager @@ -1029,6 +1033,8 @@ def run_worker_manager( worker_manager, embedding_model_name, embedding_model_path ) + worker_manager.after_start(start_listener) + if include_router: app.include_router(router, prefix="/api") diff --git a/pilot/model/cluster/worker/remote_manager.py b/pilot/model/cluster/worker/remote_manager.py index 61b608cc7..4047f428e 100644 --- a/pilot/model/cluster/worker/remote_manager.py +++ b/pilot/model/cluster/worker/remote_manager.py @@ -15,7 +15,10 @@ def __init__(self, model_registry: ModelRegistry = None) -> None: async def start(self): for listener in self.start_listeners: - listener(self) + if asyncio.iscoroutinefunction(listener): + await listener(self) + else: + listener(self) async def stop(self, ignore_exception: bool = False): pass diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py index e2deeaa02..8fd242882 100644 --- a/pilot/model/model_adapter.py +++ b/pilot/model/model_adapter.py @@ -170,9 +170,12 @@ def model_adaptation( model_context["has_format_prompt"] = True params["prompt"] = new_prompt - # Overwrite model params: - params["stop"] = conv.stop_str - params["stop_token_ids"] = conv.stop_token_ids + custom_stop = params.get("stop") + custom_stop_token_ids = params.get("stop_token_ids") + + # Prefer the value passed in from the input parameter + params["stop"] = custom_stop or 
conv.stop_str + params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids return params, model_context diff --git a/pilot/model/operator/__init__.py b/pilot/model/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/operator/model_operator.py b/pilot/model/operator/model_operator.py new file mode 100644 index 000000000..6486e8373 --- /dev/null +++ b/pilot/model/operator/model_operator.py @@ -0,0 +1,300 @@ +from typing import AsyncIterator, Dict, List, Union +import logging +from pilot.awel import ( + BranchFunc, + StreamifyAbsOperator, + BranchOperator, + MapOperator, + TransformStreamAbsOperator, +) +from pilot.awel.operator.base import BaseOperator +from pilot.model.base import ModelOutput +from pilot.model.cluster import WorkerManager +from pilot.cache import LLMCacheClient, CacheManager, LLMCacheKey, LLMCacheValue + +logger = logging.getLogger(__name__) + +_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value" +_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache" + + +class ModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]): + """Operator for streaming processing of model outputs. + + Args: + worker_manager (WorkerManager): The manager that handles worker processes for model inference. + **kwargs: Additional keyword arguments. + + Methods: + streamify: Asynchronously processes a stream of inputs, yielding model outputs. + """ + + def __init__(self, worker_manager: WorkerManager, **kwargs) -> None: + super().__init__(**kwargs) + self.worker_manager = worker_manager + + async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]: + """Process inputs as a stream and yield model outputs. + + Args: + input_value (Dict): The input value for the model. + + Returns: + AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs. + """ + async for out in self.worker_manager.generate_stream(input_value): + yield out + + +class ModelOperator(MapOperator[Dict, ModelOutput]): + """Operator for map-based processing of model outputs. + + Args: + worker_manager (WorkerManager): Manager for handling worker processes. + **kwargs: Additional keyword arguments. + + Methods: + map: Asynchronously processes a single input and returns the model output. + """ + + def __init__(self, worker_manager: WorkerManager, **kwargs) -> None: + self.worker_manager = worker_manager + super().__init__(**kwargs) + + async def map(self, input_value: Dict) -> ModelOutput: + """Process a single input and return the model output. + + Args: + input_value (Dict): The input value for the model. + + Returns: + ModelOutput: The output from the model. + """ + return await self.worker_manager.generate(input_value) + + +class CachedModelStreamOperator(StreamifyAbsOperator[Dict, ModelOutput]): + """Operator for streaming processing of model outputs with caching. + + Args: + cache_manager (CacheManager): The cache manager to handle caching operations. + **kwargs: Additional keyword arguments. + + Methods: + streamify: Processes a stream of inputs with cache support, yielding model outputs. + """ + + def __init__(self, cache_manager: CacheManager, **kwargs) -> None: + super().__init__(**kwargs) + self._cache_manager = cache_manager + self._client = LLMCacheClient(cache_manager) + + async def streamify(self, input_value: Dict) -> AsyncIterator[ModelOutput]: + """Process inputs as a stream with cache support and yield model outputs. + + Args: + input_value (Dict): The input value for the model. 
+ + Returns: + AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs. + """ + cache_dict = _parse_cache_key_dict(input_value) + llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict) + llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key) + logger.info(f"llm_cache_value: {llm_cache_value}") + for out in llm_cache_value.get_value().output: + yield out + + +class CachedModelOperator(MapOperator[Dict, ModelOutput]): + """Operator for map-based processing of model outputs with caching. + + Args: + cache_manager (CacheManager): Manager for caching operations. + **kwargs: Additional keyword arguments. + + Methods: + map: Processes a single input with cache support and returns the model output. + """ + + def __init__(self, cache_manager: CacheManager, **kwargs) -> None: + super().__init__(**kwargs) + self._cache_manager = cache_manager + self._client = LLMCacheClient(cache_manager) + + async def map(self, input_value: Dict) -> ModelOutput: + """Process a single input with cache support and return the model output. + + Args: + input_value (Dict): The input value for the model. + + Returns: + ModelOutput: The output from the model. + """ + cache_dict = _parse_cache_key_dict(input_value) + llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict) + llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key) + logger.info(f"llm_cache_value: {llm_cache_value}") + return llm_cache_value.get_value().output + + +class ModelCacheBranchOperator(BranchOperator[Dict, Dict]): + """ + A branch operator that decides whether to use cached data or to process data using the model. + + Args: + cache_manager (CacheManager): The cache manager for managing cache operations. + model_task_name (str): The name of the task to process data using the model. + cache_task_name (str): The name of the task to process data using the cache. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + cache_manager: CacheManager, + model_task_name: str, + cache_task_name: str, + **kwargs, + ): + super().__init__(branches=None, **kwargs) + self._cache_manager = cache_manager + self._client = LLMCacheClient(cache_manager) + self._model_task_name = model_task_name + self._cache_task_name = cache_task_name + + async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]: + """Defines branch logic based on cache availability. + + Returns: + Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names. + """ + + async def check_cache_true(input_value: Dict) -> bool: + # Check if the cache contains the result for the given input + cache_dict = _parse_cache_key_dict(input_value) + cache_key: LLMCacheKey = self._client.new_key(**cache_dict) + cache_value = await self._client.get(cache_key) + logger.debug( + f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}" + ) + await self.current_dag_context.save_to_share_data( + _LLM_MODEL_INPUT_VALUE_KEY, cache_key + ) + return True if cache_value else False + + async def check_cache_false(input_value: Dict): + # Inverse of check_cache_true + return not await check_cache_true(input_value) + + return { + check_cache_true: self._cache_task_name, + check_cache_false: self._model_task_name, + } + + +class ModelStreamSaveCacheOperator( + TransformStreamAbsOperator[ModelOutput, ModelOutput] +): + """An operator to save the stream of model outputs to cache. + + Args: + cache_manager (CacheManager): The cache manager for handling cache operations. 
+ **kwargs: Additional keyword arguments. + """ + + def __init__(self, cache_manager: CacheManager, **kwargs): + self._cache_manager = cache_manager + self._client = LLMCacheClient(cache_manager) + super().__init__(**kwargs) + + async def transform_stream( + self, input_value: AsyncIterator[ModelOutput] + ) -> AsyncIterator[ModelOutput]: + """Transforms the input stream by saving the outputs to cache. + + Args: + input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs. + + Returns: + AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache. + """ + llm_cache_key: LLMCacheKey = None + outputs = [] + async for out in input_value: + if not llm_cache_key: + llm_cache_key = await self.current_dag_context.get_share_data( + _LLM_MODEL_INPUT_VALUE_KEY + ) + outputs.append(out) + yield out + if llm_cache_key and _is_success_model_output(outputs): + llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs) + await self._client.set(llm_cache_key, llm_cache_value) + + +class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]): + """An operator to save a single model output to cache. + + Args: + cache_manager (CacheManager): The cache manager for handling cache operations. + **kwargs: Additional keyword arguments. + """ + + def __init__(self, cache_manager: CacheManager, **kwargs): + self._cache_manager = cache_manager + self._client = LLMCacheClient(cache_manager) + super().__init__(**kwargs) + + async def map(self, input_value: ModelOutput) -> ModelOutput: + """Saves a single model output to cache and returns it. + + Args: + input_value (ModelOutput): The output from the model to be cached. + + Returns: + ModelOutput: The same input model output. + """ + llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data( + _LLM_MODEL_INPUT_VALUE_KEY + ) + llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value) + if llm_cache_key and _is_success_model_output(input_value): + await self._client.set(llm_cache_key, llm_cache_value) + return input_value + + +def _parse_cache_key_dict(input_value: Dict) -> Dict: + """Parses and extracts relevant fields from input to form a cache key dictionary. + + Args: + input_value (Dict): The input dictionary containing model and prompt parameters. + + Returns: + Dict: A dictionary used for generating cache keys. 
+ """ + prompt: str = input_value.get("prompt") + if prompt: + prompt = prompt.strip() + return { + "prompt": prompt, + "model_name": input_value.get("model"), + "temperature": input_value.get("temperature"), + "max_new_tokens": input_value.get("max_new_tokens"), + "top_p": input_value.get("top_p", "1.0"), + # TODO pass model_type + "model_type": input_value.get("model_type", "huggingface"), + } + + +def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool: + if not out: + return False + if isinstance(out, list): + # check last model output + out = out[-1] + error_code = 0 + if isinstance(out, ModelOutput): + error_code = out.error_code + else: + error_code = int(out.get("error_code", 0)) + return error_code == 0 diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 24ec1c928..9a19f3255 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -16,6 +16,8 @@ from pilot.utils.tracer import root_tracer, trace from pydantic import Extra from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory +from pilot.awel import BaseOperator, SimpleCallDataInputSource, InputOperator, DAG +from pilot.model.operator.model_operator import ModelOperator, ModelStreamOperator logger = logging.getLogger(__name__) headers = {"User-Agent": "dbgpt Client"} @@ -88,6 +90,11 @@ def __init__(self, chat_param: Dict): ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ).create() + self._model_operator: BaseOperator = _build_model_operator() + self._model_stream_operator: BaseOperator = _build_model_operator( + is_stream=True, dag_name="llm_stream_model_dag" + ) + class Config: """Configuration for this pydantic object.""" @@ -166,7 +173,7 @@ async def __call_base(self): "messages": llm_messages, "temperature": float(self.prompt_template.temperature), "max_new_tokens": int(self.prompt_template.max_new_tokens), - "stop": self.prompt_template.sep, + # "stop": self.prompt_template.sep, "echo": self.llm_echo, } return payload @@ -204,12 +211,9 @@ async def stream_call(self): ) payload["span_id"] = span.span_id try: - from pilot.model.cluster import WorkerManagerFactory - - worker_manager = CFG.SYSTEM_APP.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - async for output in worker_manager.generate_stream(payload): + async for output in await self._model_stream_operator.call_stream( + call_data={"data": payload} + ): ### Plug-in research in result generation msg = self.prompt_template.output_parser.parse_model_stream_resp_ex( output, self.skip_echo_len @@ -240,14 +244,10 @@ async def nostream_call(self): ) payload["span_id"] = span.span_id try: - from pilot.model.cluster import WorkerManagerFactory - - worker_manager = CFG.SYSTEM_APP.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"): - model_output = await worker_manager.generate(payload) + model_output = await self._model_operator.call( + call_data={"data": payload} + ) ### output parse ai_response_text = ( @@ -307,14 +307,7 @@ async def get_llm_response(self): logger.info(f"Request: \n{payload}") ai_response_text = "" try: - from pilot.model.cluster import WorkerManagerFactory - - worker_manager = CFG.SYSTEM_APP.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - - model_output = await worker_manager.generate(payload) - + model_output = await self._model_operator.call(call_data={"data": payload}) ### output parse 
ai_response_text = (
                 self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -568,3 +561,88 @@ def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
             )
         else:
             return prompt_define_response
+
+
+def _build_model_operator(
+    is_stream: bool = False, dag_name: str = "llm_model_dag"
+) -> BaseOperator:
+    """Builds and returns a model processing workflow (DAG) operator.
+
+    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
+    It includes caching and branching logic to either fetch results from a cache or process
+    data using the model. It supports both streaming and non-streaming modes.
+
+    .. code-block:: python
+        input_node >> cache_check_branch_node
+        cache_check_branch_node >> model_node >> save_cached_node >> join_node
+        cache_check_branch_node >> cached_node >> join_node
+
+    equivalent to::
+
+                          -> model_node -> save_cached_node ->
+                         /                                    \
+        input_node -> cache_check_branch_node ---> join_node
+                         \                                    /
+                          -> cached_node ------------------- ->
+
+    Args:
+        is_stream (bool): Flag to determine if the operator should process data in streaming mode.
+        dag_name (str): Name of the DAG.
+
+    Returns:
+        BaseOperator: The final operator in the constructed DAG, typically a join node.
+    """
+    from pilot.model.cluster import WorkerManagerFactory
+    from pilot.awel import JoinOperator
+    from pilot.model.operator.model_operator import (
+        ModelCacheBranchOperator,
+        CachedModelStreamOperator,
+        CachedModelOperator,
+        ModelSaveCacheOperator,
+        ModelStreamSaveCacheOperator,
+    )
+    from pilot.cache import CacheManager
+
+    # Fetch worker and cache managers from the system configuration
+    worker_manager = CFG.SYSTEM_APP.get_component(
+        ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+    ).create()
+    cache_manager: CacheManager = CFG.SYSTEM_APP.get_component(
+        ComponentType.MODEL_CACHE_MANAGER, CacheManager
+    )
+    # Define task names for the model and cache nodes
+    model_task_name = "llm_model_node"
+    cache_task_name = "llm_model_cache_node"
+
+    with DAG(dag_name):
+        # Create an input node
+        input_node = InputOperator(SimpleCallDataInputSource())
+        # Determine if the workflow should operate in streaming mode
+        if is_stream:
+            model_node = ModelStreamOperator(worker_manager, task_name=model_task_name)
+            cached_node = CachedModelStreamOperator(
+                cache_manager, task_name=cache_task_name
+            )
+            save_cached_node = ModelStreamSaveCacheOperator(cache_manager)
+        else:
+            model_node = ModelOperator(worker_manager, task_name=model_task_name)
+            cached_node = CachedModelOperator(cache_manager, task_name=cache_task_name)
+            save_cached_node = ModelSaveCacheOperator(cache_manager)
+
+        # Create a branch node to decide between fetching from cache or processing with the model
+        cache_check_branch_node = ModelCacheBranchOperator(
+            cache_manager,
+            model_task_name=model_task_name,
+            cache_task_name=cache_task_name,
+        )
+        # Create a join node to merge outputs from the model and cache nodes, just keep the first non-empty output
+        join_node = JoinOperator(
+            combine_function=lambda model_out, cache_out: cache_out or model_out
+        )
+
+        # Define the workflow structure using the >> operator
+        input_node >> cache_check_branch_node
+        cache_check_branch_node >> model_node >> save_cached_node >> join_node
+        cache_check_branch_node >> cached_node >> join_node
+
+    return join_node
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index d127120cc..58269385b 100644
--- a/pilot/server/component_configs.py
+++ 
b/pilot/server/component_configs.py
@@ -5,6 +5,8 @@
 import os
 
 from pilot.component import ComponentType, SystemApp
+from pilot.configs.config import Config
+from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
 from pilot.utils.executor_utils import DefaultExecutorFactory
 from pilot.embedding_engine.embedding_factory import EmbeddingFactory
 from pilot.server.base import WebWerverParameters
@@ -15,6 +17,8 @@
 
 logger = logging.getLogger(__name__)
 
+CFG = Config()
+
 
 def initialize_components(
     param: WebWerverParameters,
@@ -40,6 +44,7 @@ def initialize_components(
     _initialize_embedding_model(
         param, system_app, embedding_model_name, embedding_model_path
     )
+    _initialize_model_cache(system_app)
 
 
 def _initialize_embedding_model(
@@ -131,3 +136,16 @@ def _load_model(self) -> "Embeddings":
         loader = EmbeddingLoader()
         # Ignore model_name args
         return loader.load(self._default_model_name, model_params)
+
+
+def _initialize_model_cache(system_app: SystemApp):
+    from pilot.cache import initialize_cache
+
+    if not CFG.MODEL_CACHE_ENABLE:
+        logger.info("Model cache is not enabled")
+        return
+
+    storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk"
+    max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
+    persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
+    initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
diff --git a/pilot/utils/memory_utils.py b/pilot/utils/memory_utils.py
new file mode 100644
index 000000000..cb0427c08
--- /dev/null
+++ b/pilot/utils/memory_utils.py
@@ -0,0 +1,11 @@
+from typing import Any
+from pympler import asizeof
+
+
+def _get_object_bytes(obj: Any) -> int:
+    """Get the size of an object in memory, in bytes.
+
+    Args:
+        obj (Any): The object to measure
+    """
+    return asizeof.asizeof(obj)
diff --git a/setup.py b/setup.py
index a6bea38f1..cdc95dccc 100644
--- a/setup.py
+++ b/setup.py
@@ -319,6 +319,8 @@ def core_requires():
         "alembic==1.12.0",
         # for excel
         "openpyxl",
+        # for cache. TODO: pympler has not been updated for a long time; find a replacement toolkit.
+        "pympler",
     ]
 
 
@@ -410,6 +412,13 @@ def vllm_requires():
     setup_spec.extras["vllm"] = ["vllm"]
 
 
+def cache_requires():
+    """
+    pip install "db-gpt[cache]"
+    """
+    setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
+
+
 # def chat_scene():
 #     setup_spec.extras["chat"] = [
 #         ""
@@ -460,6 +469,7 @@ def init_install_requires():
 openai_requires()
 gpt4all_requires()
 vllm_requires()
+cache_requires()
 
 # must be last
 default_requires()
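
Note: for anyone trying the new AWEL module outside the test suite, a minimal, self-contained sketch follows. It mirrors the patterns in test_run_dag.py above; chaining a map node into a reduce node is assumed to compose the same way the individual operator tests do.

.. code-block:: python

    import asyncio

    from pilot.awel import (
        DAG,
        DefaultWorkflowRunner,
        InputOperator,
        MapOperator,
        ReduceStreamOperator,
        SimpleInputSource,
    )


    async def main():
        async def num_stream():
            for i in range(5):
                yield i

        with DAG("awel_example_dag"):
            # Feed an async iterator into the workflow, double each element,
            # then fold the stream into a single sum.
            input_node = InputOperator(SimpleInputSource(num_stream()))
            map_node = MapOperator(lambda x: x * 2)
            reduce_node = ReduceStreamOperator(lambda x, y: x + y)
            input_node >> map_node >> reduce_node

        runner = DefaultWorkflowRunner()
        res = await runner.execute_workflow(reduce_node)
        # 0 + 2 + 4 + 6 + 8 == 20
        print(res.current_task_context.task_output.output)


    asyncio.run(main())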
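
Note: the cache layering (CacheKey/CacheValue over a CacheStorage) can also be exercised without the server. A rough sketch against the in-memory backend; the ModelOutput constructor fields and its to_dict() method are assumptions taken from their use elsewhere in this diff.

.. code-block:: python

    from pilot.cache.llm_cache import LLMCacheKey, LLMCacheValue
    from pilot.cache.protocal.json_protocal import JsonSerializer
    from pilot.cache.storage.base import MemoryCacheStorage
    from pilot.model.base import ModelOutput

    serializer = JsonSerializer()
    storage = MemoryCacheStorage(max_memory_mb=16)

    # The key hashes the serialized prompt/parameter dataclass, so lookups
    # are exact-match only (RetrievalPolicy.EXACT_MATCH).
    key = LLMCacheKey(
        serializer=serializer,
        prompt="Hello",
        model_name="vicuna-13b-v1.5",
        temperature=0.7,
        model_type="huggingface",
    )
    value = LLMCacheValue(
        serializer=serializer, output=ModelOutput(text="Hi there!", error_code=0)
    )

    storage.set(key, value)
    item = storage.get(key)
    assert item is not None
    # The stored item carries the serialized key and value side by side
    assert item.value_data == value.serialize()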
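
Note: how the pieces meet in BaseChat: _build_model_operator returns the join node of the cache-or-model DAG, and callers drive it with call/call_stream. A hypothetical invocation inside an already initialized DB-GPT system (the payload keys mirror _parse_cache_key_dict; the helper is private, so this is illustrative only):

.. code-block:: python

    import asyncio

    from pilot.scene.base_chat import _build_model_operator


    async def main():
        # Assumes CFG.SYSTEM_APP is initialized and the worker manager and
        # model cache manager components are registered.
        model_task = _build_model_operator()
        payload = {
            "model": "vicuna-13b-v1.5",
            "prompt": "What is AWEL?",
            "temperature": 0.7,
            "max_new_tokens": 128,
        }
        # First call: cache miss, so the branch routes to the model node and
        # the output is saved to cache; an identical payload later is served
        # by the cached node instead of the model.
        out = await model_task.call(call_data={"data": payload})
        print(out.text)


    asyncio.run(main())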