Skip to content

Commit

Permalink
feat(core): More AWEL operators and new prompt manager API (eosphoros…
Browse files Browse the repository at this point in the history
…-ai#972)

Co-authored-by: csunny <[email protected]>
  • Loading branch information
2 people authored and Hopshine committed Sep 10, 2024
1 parent e2a5a81 commit c8c7fb0
Show file tree
Hide file tree
Showing 46 changed files with 2,556 additions and 294 deletions.
8 changes: 6 additions & 2 deletions assets/schema/knowledge_management.sql
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,18 @@ CREATE TABLE IF NOT EXISTS `prompt_manage`
`chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene',
`sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene',
`prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private',
`prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name',
`content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content',
`input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))',
`model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)',
`prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)',
`prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)',
`user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name',
`sys_code` varchar(128) DEFAULT NULL COMMENT 'System code',
`gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`),
UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`),
KEY `gmt_created_idx` (`gmt_created`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table';

Expand Down
19 changes: 10 additions & 9 deletions dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from dbgpt.core.interface.llm import (
ModelInferenceMetrics,
ModelRequest,
ModelRequestContext,
ModelOutput,
LLMClient,
LLMOperator,
StreamingLLMOperator,
RequestBuildOperator,
ModelMetadata,
)
from dbgpt.core.interface.message import (
Expand All @@ -17,7 +15,11 @@
ConversationIdentifier,
MessageIdentifier,
)
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.prompt import (
PromptTemplate,
PromptManager,
StoragePromptTemplate,
)
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.cache import (
Expand All @@ -38,25 +40,24 @@
StorageError,
)


__ALL__ = [
"ModelInferenceMetrics",
"ModelRequest",
"ModelRequestContext",
"ModelOutput",
"Operator",
"RequestBuildOperator",
"ModelMetadata",
"ModelMessage",
"LLMClient",
"LLMOperator",
"StreamingLLMOperator",
"ModelMessageRoleType",
"OnceConversation",
"StorageConversation",
"MessageStorageItem",
"ConversationIdentifier",
"MessageIdentifier",
"PromptTemplate",
"PromptTemplateOperator",
"PromptManager",
"StoragePromptTemplate",
"BaseOutputParser",
"SQLOutputParser",
"Serializable",
Expand Down
28 changes: 28 additions & 0 deletions dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from typing import List, Optional
from dbgpt.component import SystemApp

from .dag.base import DAGContext, DAG
Expand Down Expand Up @@ -68,6 +69,7 @@
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"HttpTrigger",
"setup_dev_environment",
]


Expand All @@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str):
initialize_runner(DefaultWorkflowRunner())
# Load all dags
dag_manager.load_dags()


def setup_dev_environment(
dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555
) -> None:
"""Setup a development environment for AWEL.
Just using in development environment, not production environment.
"""
import uvicorn
from fastapi import FastAPI
from dbgpt.component import SystemApp
from .trigger.trigger_manager import DefaultTriggerManager
from .dag.base import DAGVar

app = FastAPI()
system_app = SystemApp(app)
DAGVar.set_current_system_app(system_app)
trigger_manager = DefaultTriggerManager()
system_app.register_instance(trigger_manager)

for dag in dags:
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
uvicorn.run(app, host=host, port=port)
163 changes: 153 additions & 10 deletions dbgpt/core/awel/dag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from dbgpt.component import SystemApp
from ..resource.base import ResourceGroup
from ..task.base import TaskContext
from ..task.base import TaskContext, TaskOutput

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -168,7 +168,19 @@ def set_executor(cls, executor: Executor) -> None:
cls._executor = executor


class DAGNode(DependencyMixin, ABC):
class DAGLifecycle:
"""The lifecycle of DAG"""

async def before_dag_run(self):
"""The callback before DAG run"""
pass

async def after_dag_end(self):
"""The callback after DAG end"""
pass


class DAGNode(DAGLifecycle, DependencyMixin, ABC):
resource_group: Optional[ResourceGroup] = None
"""The resource group of current DAGNode"""

Expand All @@ -179,7 +191,7 @@ def __init__(
node_name: Optional[str] = None,
system_app: Optional[SystemApp] = None,
executor: Optional[Executor] = None,
**kwargs
**kwargs,
) -> None:
super().__init__()
self._upstream: List["DAGNode"] = []
Expand All @@ -198,10 +210,23 @@ def __init__(
def node_id(self) -> str:
return self._node_id

@property
@abstractmethod
def dev_mode(self) -> bool:
"""Whether current DAGNode is in dev mode"""

@property
def system_app(self) -> SystemApp:
return self._system_app

def set_system_app(self, system_app: SystemApp) -> None:
"""Set system app for current DAGNode
Args:
system_app (SystemApp): The system app
"""
self._system_app = system_app

def set_node_id(self, node_id: str) -> None:
self._node_id = node_id

Expand Down Expand Up @@ -274,11 +299,41 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non
node._upstream.append(self)


def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"


class DAGContext:
def __init__(self, streaming_call: bool = False) -> None:
"""The context of current DAG, created when the DAG is running
Every DAG has been triggered will create a new DAGContext.
"""

def __init__(
self,
streaming_call: bool = False,
node_to_outputs: Dict[str, TaskContext] = None,
node_name_to_ids: Dict[str, str] = None,
) -> None:
if not node_to_outputs:
node_to_outputs = {}
if not node_name_to_ids:
node_name_to_ids = {}
self._streaming_call = streaming_call
self._curr_task_ctx = None
self._share_data: Dict[str, Any] = {}
self._node_to_outputs = node_to_outputs
self._node_name_to_ids = node_name_to_ids

@property
def _task_outputs(self) -> Dict[str, TaskContext]:
"""The task outputs of current DAG
Just use for internal for now.
Returns:
Dict[str, TaskContext]: The task outputs of current DAG
"""
return self._node_to_outputs

@property
def current_task_context(self) -> TaskContext:
Expand All @@ -292,24 +347,90 @@ def streaming_call(self) -> bool:
def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None:
self._curr_task_ctx = _curr_task_ctx

async def get_share_data(self, key: str) -> Any:
def get_task_output(self, task_name: str) -> TaskOutput:
"""Get the task output by task name
Args:
task_name (str): The task name
Returns:
TaskOutput: The task output
"""
if task_name is None:
raise ValueError("task_name can't be None")
node_id = self._node_name_to_ids.get(task_name)
if node_id:
raise ValueError(f"Task name {task_name} not exists in DAG")
return self._task_outputs.get(node_id).task_output

async def get_from_share_data(self, key: str) -> Any:
return self._share_data.get(key)

async def save_to_share_data(self, key: str, data: Any) -> None:
async def save_to_share_data(
self, key: str, data: Any, overwrite: Optional[str] = None
) -> None:
if key in self._share_data and not overwrite:
raise ValueError(f"Share data key {key} already exists")
self._share_data[key] = data

async def get_task_share_data(self, task_name: str, key: str) -> Any:
"""Get share data by task name and key
Args:
task_name (str): The task name
key (str): The share data key
Returns:
Any: The share data
"""
if task_name is None:
raise ValueError("task_name can't be None")
if key is None:
raise ValueError("key can't be None")
return self.get_from_share_data(_build_task_key(task_name, key))

async def save_task_share_data(
self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None
) -> None:
"""Save share data by task name and key
Args:
task_name (str): The task name
key (str): The share data key
data (Any): The share data
overwrite (Optional[str], optional): Whether overwrite the share data if the key already exists.
Defaults to None.
Raises:
ValueError: If the share data key already exists and overwrite is not True
"""
if task_name is None:
raise ValueError("task_name can't be None")
if key is None:
raise ValueError("key can't be None")
await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite)


class DAG:
def __init__(
self, dag_id: str, resource_group: Optional[ResourceGroup] = None
) -> None:
self._dag_id = dag_id
self.node_map: Dict[str, DAGNode] = {}
self._root_nodes: Set[DAGNode] = None
self._leaf_nodes: Set[DAGNode] = None
self._trigger_nodes: Set[DAGNode] = None
self.node_name_to_node: Dict[str, DAGNode] = {}
self._root_nodes: List[DAGNode] = None
self._leaf_nodes: List[DAGNode] = None
self._trigger_nodes: List[DAGNode] = None

def _append_node(self, node: DAGNode) -> None:
if node.node_id in self.node_map:
return
if node.node_name:
if node.node_name in self.node_name_to_node:
raise ValueError(
f"Node name {node.node_name} already exists in DAG {self.dag_id}"
)
self.node_name_to_node[node.node_name] = node
self.node_map[node.node_id] = node
# clear cached nodes
self._root_nodes = None
Expand All @@ -336,22 +457,44 @@ def _build(self) -> None:

@property
def root_nodes(self) -> List[DAGNode]:
"""The root nodes of current DAG
Returns:
List[DAGNode]: The root nodes of current DAG, no repeat
"""
if not self._root_nodes:
self._build()
return self._root_nodes

@property
def leaf_nodes(self) -> List[DAGNode]:
"""The leaf nodes of current DAG
Returns:
List[DAGNode]: The leaf nodes of current DAG, no repeat
"""
if not self._leaf_nodes:
self._build()
return self._leaf_nodes

@property
def trigger_nodes(self):
def trigger_nodes(self) -> List[DAGNode]:
"""The trigger nodes of current DAG
Returns:
List[DAGNode]: The trigger nodes of current DAG, no repeat
"""
if not self._trigger_nodes:
self._build()
return self._trigger_nodes

async def _after_dag_end(self) -> None:
"""The callback after DAG end"""
tasks = []
for node in self.node_map.values():
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)

def __enter__(self):
DAGVar.enter_dag(self)
return self
Expand Down
10 changes: 10 additions & 0 deletions dbgpt/core/awel/operator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ def __init__(
def current_dag_context(self) -> DAGContext:
return self._dag_ctx

@property
def dev_mode(self) -> bool:
"""Whether the operator is in dev mode.
In production mode, the default runner is not None.
Returns:
bool: Whether the operator is in dev mode. True if the default runner is None.
"""
return default_runner is None

async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
if not self.node_id:
raise ValueError(f"The DAG Node ID can't be empty, current node {self}")
Expand Down
Loading

0 comments on commit c8c7fb0

Please sign in to comment.