diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql index 53f6747dc..b0cf2178a 100644 --- a/assets/schema/knowledge_management.sql +++ b/assets/schema/knowledge_management.sql @@ -169,14 +169,18 @@ CREATE TABLE IF NOT EXISTS `prompt_manage` `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', - `prompt_name` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', PRIMARY KEY (`id`), - UNIQUE KEY `prompt_name_uiq` (`prompt_name`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`), KEY `gmt_created_idx` (`gmt_created`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; diff --git 
a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py index e740b1294..4be0c1f17 100644 --- a/dbgpt/core/__init__.py +++ b/dbgpt/core/__init__.py @@ -1,11 +1,9 @@ from dbgpt.core.interface.llm import ( ModelInferenceMetrics, ModelRequest, + ModelRequestContext, ModelOutput, LLMClient, - LLMOperator, - StreamingLLMOperator, - RequestBuildOperator, ModelMetadata, ) from dbgpt.core.interface.message import ( @@ -17,7 +15,11 @@ ConversationIdentifier, MessageIdentifier, ) -from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator +from dbgpt.core.interface.prompt import ( + PromptTemplate, + PromptManager, + StoragePromptTemplate, +) from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser from dbgpt.core.interface.serialization import Serializable, Serializer from dbgpt.core.interface.cache import ( @@ -38,17 +40,15 @@ StorageError, ) + __ALL__ = [ "ModelInferenceMetrics", "ModelRequest", + "ModelRequestContext", "ModelOutput", - "Operator", - "RequestBuildOperator", "ModelMetadata", "ModelMessage", "LLMClient", - "LLMOperator", - "StreamingLLMOperator", "ModelMessageRoleType", "OnceConversation", "StorageConversation", @@ -56,7 +56,8 @@ "ConversationIdentifier", "MessageIdentifier", "PromptTemplate", - "PromptTemplateOperator", + "PromptManager", + "StoragePromptTemplate", "BaseOutputParser", "SQLOutputParser", "Serializable", diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py index c3361e883..9331d2621 100644 --- a/dbgpt/core/awel/__init__.py +++ b/dbgpt/core/awel/__init__.py @@ -7,6 +7,7 @@ """ +from typing import List, Optional from dbgpt.component import SystemApp from .dag.base import DAGContext, DAG @@ -68,6 +69,7 @@ "UnstreamifyAbsOperator", "TransformStreamAbsOperator", "HttpTrigger", + "setup_dev_environment", ] @@ -85,3 +87,29 @@ def initialize_awel(system_app: SystemApp, dag_filepath: str): initialize_runner(DefaultWorkflowRunner()) # Load all dags dag_manager.load_dags() + + +def 
setup_dev_environment( + dags: List[DAG], host: Optional[str] = "0.0.0.0", port: Optional[int] = 5555 +) -> None: + """Setup a development environment for AWEL. + + Just using in development environment, not production environment. + """ + import uvicorn + from fastapi import FastAPI + from dbgpt.component import SystemApp + from .trigger.trigger_manager import DefaultTriggerManager + from .dag.base import DAGVar + + app = FastAPI() + system_app = SystemApp(app) + DAGVar.set_current_system_app(system_app) + trigger_manager = DefaultTriggerManager() + system_app.register_instance(trigger_manager) + + for dag in dags: + for trigger in dag.trigger_nodes: + trigger_manager.register_trigger(trigger) + trigger_manager.after_register() + uvicorn.run(app, host=host, port=port) diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index e459215d6..5f182a97b 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -11,7 +11,7 @@ from dbgpt.component import SystemApp from ..resource.base import ResourceGroup -from ..task.base import TaskContext +from ..task.base import TaskContext, TaskOutput logger = logging.getLogger(__name__) @@ -168,7 +168,19 @@ def set_executor(cls, executor: Executor) -> None: cls._executor = executor -class DAGNode(DependencyMixin, ABC): +class DAGLifecycle: + """The lifecycle of DAG""" + + async def before_dag_run(self): + """The callback before DAG run""" + pass + + async def after_dag_end(self): + """The callback after DAG end""" + pass + + +class DAGNode(DAGLifecycle, DependencyMixin, ABC): resource_group: Optional[ResourceGroup] = None """The resource group of current DAGNode""" @@ -179,7 +191,7 @@ def __init__( node_name: Optional[str] = None, system_app: Optional[SystemApp] = None, executor: Optional[Executor] = None, - **kwargs + **kwargs, ) -> None: super().__init__() self._upstream: List["DAGNode"] = [] @@ -198,10 +210,23 @@ def __init__( def node_id(self) -> str: return self._node_id + @property + 
@abstractmethod + def dev_mode(self) -> bool: + """Whether current DAGNode is in dev mode""" + @property def system_app(self) -> SystemApp: return self._system_app + def set_system_app(self, system_app: SystemApp) -> None: + """Set system app for current DAGNode + + Args: + system_app (SystemApp): The system app + """ + self._system_app = system_app + def set_node_id(self, node_id: str) -> None: self._node_id = node_id @@ -274,11 +299,41 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non node._upstream.append(self) +def _build_task_key(task_name: str, key: str) -> str: + return f"{task_name}___$$$$$$___{key}" + + class DAGContext: - def __init__(self, streaming_call: bool = False) -> None: + """The context of current DAG, created when the DAG is running + + Every DAG has been triggered will create a new DAGContext. + """ + + def __init__( + self, + streaming_call: bool = False, + node_to_outputs: Dict[str, TaskContext] = None, + node_name_to_ids: Dict[str, str] = None, + ) -> None: + if not node_to_outputs: + node_to_outputs = {} + if not node_name_to_ids: + node_name_to_ids = {} self._streaming_call = streaming_call self._curr_task_ctx = None self._share_data: Dict[str, Any] = {} + self._node_to_outputs = node_to_outputs + self._node_name_to_ids = node_name_to_ids + + @property + def _task_outputs(self) -> Dict[str, TaskContext]: + """The task outputs of current DAG + + Just use for internal for now. 
+ Returns: + Dict[str, TaskContext]: The task outputs of current DAG + """ + return self._node_to_outputs @property def current_task_context(self) -> TaskContext: @@ -292,12 +347,69 @@ def streaming_call(self) -> bool: def set_current_task_context(self, _curr_task_ctx: TaskContext) -> None: self._curr_task_ctx = _curr_task_ctx - async def get_share_data(self, key: str) -> Any: + def get_task_output(self, task_name: str) -> TaskOutput: + """Get the task output by task name + + Args: + task_name (str): The task name + + Returns: + TaskOutput: The task output + """ + if task_name is None: + raise ValueError("task_name can't be None") + node_id = self._node_name_to_ids.get(task_name) + if node_id: + raise ValueError(f"Task name {task_name} not exists in DAG") + return self._task_outputs.get(node_id).task_output + + async def get_from_share_data(self, key: str) -> Any: return self._share_data.get(key) - async def save_to_share_data(self, key: str, data: Any) -> None: + async def save_to_share_data( + self, key: str, data: Any, overwrite: Optional[str] = None + ) -> None: + if key in self._share_data and not overwrite: + raise ValueError(f"Share data key {key} already exists") self._share_data[key] = data + async def get_task_share_data(self, task_name: str, key: str) -> Any: + """Get share data by task name and key + + Args: + task_name (str): The task name + key (str): The share data key + + Returns: + Any: The share data + """ + if task_name is None: + raise ValueError("task_name can't be None") + if key is None: + raise ValueError("key can't be None") + return self.get_from_share_data(_build_task_key(task_name, key)) + + async def save_task_share_data( + self, task_name: str, key: str, data: Any, overwrite: Optional[str] = None + ) -> None: + """Save share data by task name and key + + Args: + task_name (str): The task name + key (str): The share data key + data (Any): The share data + overwrite (Optional[str], optional): Whether overwrite the share data if the key 
already exists. + Defaults to None. + + Raises: + ValueError: If the share data key already exists and overwrite is not True + """ + if task_name is None: + raise ValueError("task_name can't be None") + if key is None: + raise ValueError("key can't be None") + await self.save_to_share_data(_build_task_key(task_name, key), data, overwrite) + class DAG: def __init__( @@ -305,11 +417,20 @@ def __init__( ) -> None: self._dag_id = dag_id self.node_map: Dict[str, DAGNode] = {} - self._root_nodes: Set[DAGNode] = None - self._leaf_nodes: Set[DAGNode] = None - self._trigger_nodes: Set[DAGNode] = None + self.node_name_to_node: Dict[str, DAGNode] = {} + self._root_nodes: List[DAGNode] = None + self._leaf_nodes: List[DAGNode] = None + self._trigger_nodes: List[DAGNode] = None def _append_node(self, node: DAGNode) -> None: + if node.node_id in self.node_map: + return + if node.node_name: + if node.node_name in self.node_name_to_node: + raise ValueError( + f"Node name {node.node_name} already exists in DAG {self.dag_id}" + ) + self.node_name_to_node[node.node_name] = node self.node_map[node.node_id] = node # clear cached nodes self._root_nodes = None @@ -336,22 +457,44 @@ def _build(self) -> None: @property def root_nodes(self) -> List[DAGNode]: + """The root nodes of current DAG + + Returns: + List[DAGNode]: The root nodes of current DAG, no repeat + """ if not self._root_nodes: self._build() return self._root_nodes @property def leaf_nodes(self) -> List[DAGNode]: + """The leaf nodes of current DAG + + Returns: + List[DAGNode]: The leaf nodes of current DAG, no repeat + """ if not self._leaf_nodes: self._build() return self._leaf_nodes @property - def trigger_nodes(self): + def trigger_nodes(self) -> List[DAGNode]: + """The trigger nodes of current DAG + + Returns: + List[DAGNode]: The trigger nodes of current DAG, no repeat + """ if not self._trigger_nodes: self._build() return self._trigger_nodes + async def _after_dag_end(self) -> None: + """The callback after DAG end""" + 
tasks = [] + for node in self.node_map.values(): + tasks.append(node.after_dag_end()) + await asyncio.gather(*tasks) + def __enter__(self): DAGVar.enter_dag(self) return self diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py index c114412b0..cc106dbf6 100644 --- a/dbgpt/core/awel/operator/base.py +++ b/dbgpt/core/awel/operator/base.py @@ -146,6 +146,16 @@ def __init__( def current_dag_context(self) -> DAGContext: return self._dag_ctx + @property + def dev_mode(self) -> bool: + """Whether the operator is in dev mode. + In production mode, the default runner is not None. + + Returns: + bool: Whether the operator is in dev mode. True if the default runner is None. + """ + return default_runner is None + async def _run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: if not self.node_id: raise ValueError(f"The DAG Node ID can't be empty, current node {self}") diff --git a/dbgpt/core/awel/operator/common_operator.py b/dbgpt/core/awel/operator/common_operator.py index 8015368c8..bd1199aa7 100644 --- a/dbgpt/core/awel/operator/common_operator.py +++ b/dbgpt/core/awel/operator/common_operator.py @@ -1,4 +1,14 @@ -from typing import Generic, Dict, List, Union, Callable, Any, AsyncIterator, Awaitable +from typing import ( + Generic, + Dict, + List, + Union, + Callable, + Any, + AsyncIterator, + Awaitable, + Optional, +) import asyncio import logging @@ -162,7 +172,9 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]): """ def __init__( - self, branches: Dict[BranchFunc[IN], Union[BaseOperator, str]], **kwargs + self, + branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None, + **kwargs, ): """ Initializes a BranchDAGNode with a branching function. 
@@ -203,7 +215,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: branches = self._branches if not branches: - branches = await self.branchs() + branches = await self.branches() branch_func_tasks = [] branch_nodes: List[str] = [] @@ -229,7 +241,7 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx.update_metadata("skip_node_names", skip_node_names) return parent_output - async def branchs(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]: + async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]: raise NotImplementedError diff --git a/dbgpt/core/awel/runner/job_manager.py b/dbgpt/core/awel/runner/job_manager.py index 7a1d12ead..44c1af01b 100644 --- a/dbgpt/core/awel/runner/job_manager.py +++ b/dbgpt/core/awel/runner/job_manager.py @@ -1,7 +1,8 @@ +import asyncio from typing import List, Set, Optional, Dict import uuid import logging -from ..dag.base import DAG +from ..dag.base import DAG, DAGLifecycle from ..operator.base import BaseOperator, CALL_DATA @@ -18,18 +19,20 @@ def __init__(self, dag: DAG) -> None: self._dag = dag -class JobManager: +class JobManager(DAGLifecycle): def __init__( self, root_nodes: List[BaseOperator], all_nodes: List[BaseOperator], end_node: BaseOperator, id2call_data: Dict[str, Dict], + node_name_to_ids: Dict[str, str], ) -> None: self._root_nodes = root_nodes self._all_nodes = all_nodes self._end_node = end_node self._id2node_data = id2call_data + self._node_name_to_ids = node_name_to_ids @staticmethod def build_from_end_node( @@ -38,11 +41,31 @@ def build_from_end_node( nodes = _build_from_end_node(end_node) root_nodes = _get_root_nodes(nodes) id2call_data = _save_call_data(root_nodes, call_data) - return JobManager(root_nodes, nodes, end_node, id2call_data) + + node_name_to_ids = {} + for node in nodes: + if node.node_name is not None: + node_name_to_ids[node.node_name] = node.node_id + + return JobManager(root_nodes, nodes, end_node, id2call_data, 
node_name_to_ids) def get_call_data_by_id(self, node_id: str) -> Optional[Dict]: return self._id2node_data.get(node_id) + async def before_dag_run(self): + """The callback before DAG run""" + tasks = [] + for node in self._all_nodes: + tasks.append(node.before_dag_run()) + await asyncio.gather(*tasks) + + async def after_dag_end(self): + """The callback after DAG end""" + tasks = [] + for node in self._all_nodes: + tasks.append(node.after_dag_end()) + await asyncio.gather(*tasks) + def _save_call_data( root_nodes: List[BaseOperator], call_data: CALL_DATA @@ -66,6 +89,7 @@ def _save_call_data( def _build_from_end_node(end_node: BaseOperator) -> List[BaseOperator]: + """Build all nodes from the end node.""" nodes = [] if isinstance(end_node, BaseOperator): task_id = end_node.node_id diff --git a/dbgpt/core/awel/runner/local_runner.py b/dbgpt/core/awel/runner/local_runner.py index e29874ebe..1bc8fdc9d 100644 --- a/dbgpt/core/awel/runner/local_runner.py +++ b/dbgpt/core/awel/runner/local_runner.py @@ -1,7 +1,8 @@ from typing import Dict, Optional, Set, List import logging -from ..dag.base import DAGContext +from dbgpt.component import SystemApp +from ..dag.base import DAGContext, DAGVar from ..operator.base import WorkflowRunner, BaseOperator, CALL_DATA from ..operator.common_operator import BranchOperator, JoinOperator, TriggerOperator from ..task.base import TaskContext, TaskState @@ -18,19 +19,29 @@ async def execute_workflow( call_data: Optional[CALL_DATA] = None, streaming_call: bool = False, ) -> DAGContext: - # Create DAG context - dag_ctx = DAGContext(streaming_call=streaming_call) + # Save node output + # dag = node.dag + node_outputs: Dict[str, TaskContext] = {} job_manager = JobManager.build_from_end_node(node, call_data) + # Create DAG context + dag_ctx = DAGContext( + streaming_call=streaming_call, + node_to_outputs=node_outputs, + node_name_to_ids=job_manager._node_name_to_ids, + ) logger.info( f"Begin run workflow from end operator, id: {node.node_id}, 
call_data: {call_data}" ) - dag = node.dag - # Save node output - node_outputs: Dict[str, TaskContext] = {} skip_node_ids = set() + system_app: SystemApp = DAGVar.get_current_system_app() + + await job_manager.before_dag_run() await self._execute_node( - job_manager, node, dag_ctx, node_outputs, skip_node_ids + job_manager, node, dag_ctx, node_outputs, skip_node_ids, system_app ) + if not streaming_call and node.dag: + # streaming call not work for dag end + await node.dag._after_dag_end() return dag_ctx @@ -41,6 +52,7 @@ async def _execute_node( dag_ctx: DAGContext, node_outputs: Dict[str, TaskContext], skip_node_ids: Set[str], + system_app: SystemApp, ): # Skip run node if node.node_id in node_outputs: @@ -50,7 +62,12 @@ async def _execute_node( for upstream_node in node.upstream: if isinstance(upstream_node, BaseOperator): await self._execute_node( - job_manager, upstream_node, dag_ctx, node_outputs, skip_node_ids + job_manager, + upstream_node, + dag_ctx, + node_outputs, + skip_node_ids, + system_app, ) inputs = [ @@ -73,6 +90,9 @@ async def _execute_node( logger.debug( f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}" ) + if system_app is not None and node.system_app is None: + node.set_system_app(system_app) + await node._run(dag_ctx) node_outputs[node.node_id] = dag_ctx.current_task_context task_ctx.set_current_state(TaskState.SUCCESS) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 58f0ba529..192165f10 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict +from typing import Union, Type, List, TYPE_CHECKING, Optional, Any, Dict, Callable from starlette.requests import Request from starlette.responses import Response from dbgpt._private.pydantic import BaseModel @@ -13,7 +13,8 @@ if TYPE_CHECKING: 
from fastapi import APIRouter, FastAPI -RequestBody = Union[Request, Type[BaseModel], str] +RequestBody = Union[Type[Request], Type[BaseModel], str] +StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool] logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def __init__( methods: Optional[Union[str, List[str]]] = "GET", request_body: Optional[RequestBody] = None, streaming_response: Optional[bool] = False, + streaming_predict_func: Optional[StreamingPredictFunc] = None, response_model: Optional[Type] = None, response_headers: Optional[Dict[str, str]] = None, response_media_type: Optional[str] = None, @@ -39,6 +41,7 @@ def __init__( self._methods = methods self._req_body = request_body self._streaming_response = streaming_response + self._streaming_predict_func = streaming_predict_func self._response_model = response_model self._status_code = status_code self._router_tags = router_tags @@ -59,10 +62,13 @@ async def _request_body_dependency(request: Request): return await _parse_request_body(request, self._req_body) async def route_function(body=Depends(_request_body_dependency)): + streaming_response = self._streaming_response + if self._streaming_predict_func: + streaming_response = self._streaming_predict_func(body) return await _trigger_dag( body, self.dag, - self._streaming_response, + streaming_response, self._response_headers, self._response_media_type, ) @@ -112,6 +118,7 @@ async def _trigger_dag( response_headers: Optional[Dict[str, str]] = None, response_media_type: Optional[str] = None, ) -> Any: + from fastapi import BackgroundTasks from fastapi.responses import StreamingResponse end_node = dag.leaf_nodes @@ -131,8 +138,11 @@ async def _trigger_dag( "Transfer-Encoding": "chunked", } generator = await end_node.call_stream(call_data={"data": body}) + background_tasks = BackgroundTasks() + background_tasks.add_task(end_node.dag._after_dag_end) return StreamingResponse( generator, headers=headers, media_type=media_type, + 
background=background_tasks, ) diff --git a/dbgpt/core/interface/cache.py b/dbgpt/core/interface/cache.py index 539758d36..63babb7a2 100644 --- a/dbgpt/core/interface/cache.py +++ b/dbgpt/core/interface/cache.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod - -from typing import Any, TypeVar, Generic, Optional from dataclasses import dataclass from enum import Enum +from typing import Any, Generic, Optional, TypeVar from dbgpt.core.interface.serialization import Serializable diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 49fe84aa7..ebc6088c2 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -1,14 +1,13 @@ -from abc import ABC, abstractmethod -from typing import Optional, Dict, List, Any, Union, AsyncIterator -import time -from dataclasses import dataclass, asdict, field import copy +import time +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, field +from typing import Any, AsyncIterator, Dict, List, Optional, Union +from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.util import BaseParameters from dbgpt.util.annotations import PublicAPI from dbgpt.util.model_utils import GPUInfo -from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType -from dbgpt.core.awel import MapOperator, StreamifyAbsOperator @dataclass @@ -97,6 +96,28 @@ def to_dict(self) -> Dict: return asdict(self) +@dataclass +@PublicAPI(stability="beta") +class ModelRequestContext: + stream: Optional[bool] = False + """Whether to return a stream of responses.""" + + user_name: Optional[str] = None + """The user name of the model request.""" + + sys_code: Optional[str] = None + """The system code of the model request.""" + + conv_uid: Optional[str] = None + """The conversation id of the model inference.""" + + span_id: Optional[str] = None + """The span id of the model inference.""" + + extra: Optional[Dict[str, Any]] = field(default_factory=dict) + """The 
extra information of the model inference.""" + + @dataclass @PublicAPI(stability="beta") class ModelOutput: @@ -145,6 +166,27 @@ class ModelRequest: span_id: Optional[str] = None """The span id of the model inference.""" + context: Optional[ModelRequestContext] = field( + default_factory=lambda: ModelRequestContext() + ) + """The context of the model inference.""" + + @property + def stream(self) -> bool: + """Whether to return a stream of responses.""" + return self.context and self.context.stream + + def copy(self): + new_request = copy.deepcopy(self) + # Transform messages to List[ModelMessage] + new_request.messages = list( + map( + lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m), + new_request.messages, + ) + ) + return new_request + def to_dict(self) -> Dict[str, Any]: new_reqeust = copy.deepcopy(self) new_reqeust.messages = list( @@ -161,6 +203,17 @@ def _get_messages(self) -> List[ModelMessage]: ) ) + def get_single_user_message(self) -> Optional[ModelMessage]: + """Get the single user message. + + Returns: + Optional[ModelMessage]: The single user message. + """ + messages = self._get_messages() + if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN: + raise ValueError("The messages is not a single user message") + return messages[0] + @staticmethod def _build(model: str, prompt: str, **kwargs): return ModelRequest( @@ -178,11 +231,22 @@ def to_openai_messages(self) -> List[Dict[str, Any]]: List[Dict[str, Any]]: The messages in the format of OpenAI API. Examples: + .. code-block:: python + + from dbgpt.core.interface.message import ( + ModelMessage, + ModelMessageRoleType, + ) + messages = [ ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"), - ModelMessage(role=ModelMessageRoleType.AI, content="Hi, I'm a robot.") - ModelMessage(role=ModelMessageRoleType.HUMAN, content="Who are your"), + ModelMessage( + role=ModelMessageRoleType.AI, content="Hi, I'm a robot." 
+ ), + ModelMessage( + role=ModelMessageRoleType.HUMAN, content="Who are your" + ), ] openai_messages = ModelRequest.to_openai_messages(messages) assert openai_messages == [ @@ -272,63 +336,3 @@ async def count_token(self, model: str, prompt: str) -> int: Returns: int: The number of tokens. """ - - -class RequestBuildOperator(MapOperator[str, ModelRequest], ABC): - def __init__(self, model: str, **kwargs): - self._model = model - super().__init__(**kwargs) - - async def map(self, input_value: str) -> ModelRequest: - return ModelRequest._build(self._model, input_value) - - -class BaseLLM: - """The abstract operator for a LLM.""" - - def __init__(self, llm_client: Optional[LLMClient] = None): - self._llm_client = llm_client - - @property - def llm_client(self) -> LLMClient: - """Return the LLM client.""" - if not self._llm_client: - raise ValueError("llm_client is not set") - return self._llm_client - - -class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): - """The operator for a LLM. - - Args: - llm_client (LLMClient, optional): The LLM client. Defaults to None. - - This operator will generate a no streaming response. - """ - - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client=llm_client) - MapOperator.__init__(self, **kwargs) - - async def map(self, request: ModelRequest) -> ModelOutput: - return await self.llm_client.generate(request) - - -class StreamingLLMOperator( - BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC -): - """The streaming operator for a LLM. - - Args: - llm_client (LLMClient, optional): The LLM client. Defaults to None. - - This operator will generate streaming response. 
- """ - - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client=llm_client) - StreamifyAbsOperator.__init__(self, **kwargs) - - async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]: - async for output in self.llm_client.generate_stream(request): - yield output diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 2493ebb53..005452db2 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -1,16 +1,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, List, Tuple, Union, Optional from datetime import datetime +from typing import Dict, List, Optional, Tuple, Union from dbgpt._private.pydantic import BaseModel, Field - +from dbgpt.core.awel import MapOperator from dbgpt.core.interface.storage import ( + InMemoryStorage, ResourceIdentifier, - StorageItem, StorageInterface, - InMemoryStorage, + StorageItem, ) @@ -112,6 +112,7 @@ class ModelMessage(BaseModel): """Similar to openai's message format""" role: str content: str + round_index: Optional[int] = 0 @staticmethod def from_openai_messages( @@ -443,6 +444,7 @@ def from_conversation(self, conversation: OnceConversation) -> None: self.tokens = conversation.tokens self.user_name = conversation.user_name self.sys_code = conversation.sys_code + self._message_index = conversation._message_index def get_messages_by_round(self, round_index: int) -> List[BaseMessage]: """Get the messages by round index @@ -470,6 +472,7 @@ def get_messages_with_round(self, round_count: int) -> List[BaseMessage]: Example: .. 
code-block:: python + conversation = OnceConversation() conversation.start_new_round() conversation.add_user_message("hello, this is the first round") @@ -485,11 +488,17 @@ def get_messages_with_round(self, round_count: int) -> List[BaseMessage]: conversation.end_current_round() assert len(conversation.get_messages_with_round(1)) == 2 - assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round" + assert ( + conversation.get_messages_with_round(1)[0].content + == "hello, this is the third round" + ) assert conversation.get_messages_with_round(1)[1].content == "hi" assert len(conversation.get_messages_with_round(2)) == 4 - assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round" + assert ( + conversation.get_messages_with_round(2)[0].content + == "hello, this is the second round" + ) assert conversation.get_messages_with_round(2)[1].content == "hi" Args: @@ -517,6 +526,7 @@ def get_model_messages(self) -> List[ModelMessage]: Examples: If you not need the history messages, you can override this method like this: .. code-block:: python + def get_model_messages(self) -> List[ModelMessage]: messages = [] for message in self.get_latest_round(): @@ -528,6 +538,7 @@ def get_model_messages(self) -> List[ModelMessage]: If you want to add the one round history messages, you can override this method like this: .. 
code-block:: python + def get_model_messages(self) -> List[ModelMessage]: messages = [] latest_round_index = self.chat_order @@ -537,7 +548,9 @@ def get_model_messages(self) -> List[ModelMessage]: for message in self.get_messages_by_round(round_index): if message.pass_to_model: messages.append( - ModelMessage(role=message.type, content=message.content) + ModelMessage( + role=message.type, content=message.content + ) ) return messages @@ -548,7 +561,11 @@ def get_model_messages(self) -> List[ModelMessage]: for message in self.messages: if message.pass_to_model: messages.append( - ModelMessage(role=message.type, content=message.content) + ModelMessage( + role=message.type, + content=message.content, + round_index=message.round_index, + ) ) return messages @@ -780,6 +797,9 @@ def load_from_storage( ) messages = [message.to_message() for message in message_list] conversation.messages = messages + # This index is used to save the message to the storage(Has not been saved) + # The new message append to the messages, so the index is len(messages) + conversation._message_index = len(messages) self._message_ids = message_ids self._has_stored_message_index = len(messages) - 1 self.from_conversation(conversation) diff --git a/dbgpt/core/interface/operator/__init__.py b/dbgpt/core/interface/operator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/core/interface/operator/llm_operator.py b/dbgpt/core/interface/operator/llm_operator.py new file mode 100644 index 000000000..fc117ddc5 --- /dev/null +++ b/dbgpt/core/interface/operator/llm_operator.py @@ -0,0 +1,166 @@ +import dataclasses +from abc import ABC +from typing import Any, AsyncIterator, Dict, Optional, Union + +from dbgpt._private.pydantic import BaseModel +from dbgpt.core.awel import ( + BranchFunc, + BranchOperator, + MapOperator, + StreamifyAbsOperator, +) +from dbgpt.core.interface.llm import ( + LLMClient, + ModelOutput, + ModelRequest, + ModelRequestContext, +) +from 
dbgpt.core.interface.message import ModelMessage + +RequestInput = Union[ + ModelRequest, + str, + Dict[str, Any], + BaseModel, +] + + +class RequestBuildOperator(MapOperator[RequestInput, ModelRequest], ABC): + def __init__(self, model: Optional[str] = None, **kwargs): + self._model = model + super().__init__(**kwargs) + + async def map(self, input_value: RequestInput) -> ModelRequest: + req_dict = {} + if isinstance(input_value, str): + req_dict = {"messages": [ModelMessage.build_human_message(input_value)]} + elif isinstance(input_value, dict): + req_dict = input_value + elif dataclasses.is_dataclass(input_value): + req_dict = dataclasses.asdict(input_value) + elif isinstance(input_value, BaseModel): + req_dict = input_value.dict() + elif isinstance(input_value, ModelRequest): + if not input_value.model: + input_value.model = self._model + return input_value + if "messages" not in req_dict: + raise ValueError("messages is not set") + messages = req_dict["messages"] + if isinstance(messages, str): + # Single message, transform to a list including one human message + req_dict["messages"] = [ModelMessage.build_human_message(messages)] + if "model" not in req_dict: + req_dict["model"] = self._model + if not req_dict["model"]: + raise ValueError("model is not set") + stream = False + has_stream = False + if "stream" in req_dict: + has_stream = True + stream = req_dict["stream"] + del req_dict["stream"] + if "context" not in req_dict: + req_dict["context"] = ModelRequestContext(stream=stream) + else: + context_dict = req_dict["context"] + if not isinstance(context_dict, dict): + raise ValueError("context is not a dict") + if has_stream: + context_dict["stream"] = stream + req_dict["context"] = ModelRequestContext(**context_dict) + return ModelRequest(**req_dict) + + +class BaseLLM: + """The abstract operator for a LLM.""" + + SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name" + + def __init__(self, llm_client: Optional[LLMClient] = None): + self._llm_client = 
llm_client + + @property + def llm_client(self) -> LLMClient: + """Return the LLM client.""" + if not self._llm_client: + raise ValueError("llm_client is not set") + return self._llm_client + + +class LLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): + """The operator for a LLM. + + Args: + llm_client (LLMClient, optional): The LLM client. Defaults to None. + + This operator will generate a no streaming response. + """ + + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client=llm_client) + MapOperator.__init__(self, **kwargs) + + async def map(self, request: ModelRequest) -> ModelOutput: + await self.current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_MODEL_NAME, request.model + ) + return await self.llm_client.generate(request) + + +class StreamingLLMOperator( + BaseLLM, StreamifyAbsOperator[ModelRequest, ModelOutput], ABC +): + """The streaming operator for a LLM. + + Args: + llm_client (LLMClient, optional): The LLM client. Defaults to None. + + This operator will generate streaming response. + """ + + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client=llm_client) + StreamifyAbsOperator.__init__(self, **kwargs) + + async def streamify(self, request: ModelRequest) -> AsyncIterator[ModelOutput]: + await self.current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_MODEL_NAME, request.model + ) + async for output in self.llm_client.generate_stream(request): + yield output + + +class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]): + """Branch operator for LLM. + + This operator will branch the workflow based on the stream flag of the request. 
+ """ + + def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs): + super().__init__(**kwargs) + if not stream_task_name: + raise ValueError("stream_task_name is not set") + if not no_stream_task_name: + raise ValueError("no_stream_task_name is not set") + self._stream_task_name = stream_task_name + self._no_stream_task_name = no_stream_task_name + + async def branches(self) -> Dict[BranchFunc[ModelRequest], str]: + """ + Return a dict of branch function and task name. + + Returns: + Dict[BranchFunc[ModelRequest], str]: A dict of branch function and task name. + the key is a predicate function, the value is the task name. If the predicate function returns True, + we will run the corresponding task. + """ + + async def check_stream_true(r: ModelRequest) -> bool: + # If stream is true, we will run the streaming task. otherwise, we will run the non-streaming task. + return r.stream + + return { + check_stream_true: self._stream_task_name, + lambda x: not x.stream: self._no_stream_task_name, + } diff --git a/dbgpt/core/interface/operator/message_operator.py b/dbgpt/core/interface/operator/message_operator.py new file mode 100644 index 000000000..d775e8401 --- /dev/null +++ b/dbgpt/core/interface/operator/message_operator.py @@ -0,0 +1,321 @@ +import uuid +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, List, Optional + +from dbgpt.core import ( + MessageStorageItem, + ModelMessage, + ModelOutput, + ModelRequest, + ModelRequestContext, + StorageConversation, + StorageInterface, +) +from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator + + +class BaseConversationOperator(BaseOperator, ABC): + """Base class for conversation operators.""" + + SHARE_DATA_KEY_STORAGE_CONVERSATION = "share_data_key_storage_conversation" + SHARE_DATA_KEY_MODEL_REQUEST = "share_data_key_model_request" + + def __init__( + self, + storage: Optional[StorageInterface[StorageConversation, Any]] = None, + message_storage: 
Optional[StorageInterface[MessageStorageItem, Any]] = None, + **kwargs + ): + super().__init__(**kwargs) + self._storage = storage + self._message_storage = message_storage + + @property + def storage(self) -> StorageInterface[StorageConversation, Any]: + """Return the LLM client.""" + if not self._storage: + raise ValueError("Storage is not set") + return self._storage + + @property + def message_storage(self) -> StorageInterface[MessageStorageItem, Any]: + """Return the LLM client.""" + if not self._message_storage: + raise ValueError("Message storage is not set") + return self._message_storage + + async def get_storage_conversation(self) -> StorageConversation: + """Get the storage conversation from share data. + + Returns: + StorageConversation: The storage conversation. + """ + storage_conv: StorageConversation = ( + await self.current_dag_context.get_from_share_data( + self.SHARE_DATA_KEY_STORAGE_CONVERSATION + ) + ) + if not storage_conv: + raise ValueError("Storage conversation is not set") + return storage_conv + + async def get_model_request(self) -> ModelRequest: + """Get the model request from share data. + + Returns: + ModelRequest: The model request. + """ + model_request: ModelRequest = ( + await self.current_dag_context.get_from_share_data( + self.SHARE_DATA_KEY_MODEL_REQUEST + ) + ) + if not model_request: + raise ValueError("Model request is not set") + return model_request + + +class PreConversationOperator( + BaseConversationOperator, MapOperator[ModelRequest, ModelRequest] +): + """The operator to prepare the storage conversation. + + In DB-GPT, conversation record and the messages in the conversation are stored in the storage, + and they can store in different storage(for high performance). 
+ """ + + def __init__( + self, + storage: Optional[StorageInterface[StorageConversation, Any]] = None, + message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None, + **kwargs + ): + super().__init__(storage=storage, message_storage=message_storage) + MapOperator.__init__(self, **kwargs) + + async def map(self, input_value: ModelRequest) -> ModelRequest: + """Map the input value to a ModelRequest. + + Args: + input_value (ModelRequest): The input value. + + Returns: + ModelRequest: The mapped ModelRequest. + """ + if input_value.context is None: + input_value.context = ModelRequestContext() + if not input_value.context.conv_uid: + input_value.context.conv_uid = str(uuid.uuid4()) + if not input_value.context.extra: + input_value.context.extra = {} + + chat_mode = input_value.context.extra.get("chat_mode") + + # Create a new storage conversation, this will load the conversation from storage, so we must do this async + storage_conv: StorageConversation = await self.blocking_func_to_async( + StorageConversation, + conv_uid=input_value.context.conv_uid, + chat_mode=chat_mode, + user_name=input_value.context.user_name, + sys_code=input_value.context.sys_code, + conv_storage=self.storage, + message_storage=self.message_storage, + ) + # The input message must be a single user message + single_human_message: ModelMessage = input_value.get_single_user_message() + storage_conv.start_new_round() + storage_conv.add_user_message(single_human_message.content) + + # Get all messages from current storage conversation, and overwrite the input value + messages: List[ModelMessage] = storage_conv.get_model_messages() + input_value.messages = messages + + # Save the storage conversation to share data, for the child operators + await self.current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv + ) + await self.current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_MODEL_REQUEST, input_value + ) + return input_value + + 
async def after_dag_end(self): + """The callback after DAG end""" + # Save the storage conversation to storage after the whole DAG finished + storage_conv: StorageConversation = await self.get_storage_conversation() + # TODO dont save if the conversation has some internal error + storage_conv.end_current_round() + + +class PostConversationOperator( + BaseConversationOperator, MapOperator[ModelOutput, ModelOutput] +): + def __init__(self, **kwargs): + MapOperator.__init__(self, **kwargs) + + async def map(self, input_value: ModelOutput) -> ModelOutput: + """Map the input value to a ModelOutput. + + Args: + input_value (ModelOutput): The input value. + + Returns: + ModelOutput: The mapped ModelOutput. + """ + # Get the storage conversation from share data + storage_conv: StorageConversation = await self.get_storage_conversation() + storage_conv.add_ai_message(input_value.text) + return input_value + + +class PostStreamingConversationOperator( + BaseConversationOperator, TransformStreamAbsOperator[ModelOutput, ModelOutput] +): + def __init__(self, **kwargs): + TransformStreamAbsOperator.__init__(self, **kwargs) + + async def transform_stream( + self, input_value: AsyncIterator[ModelOutput] + ) -> ModelOutput: + """Transform the input value to a ModelOutput. + + Args: + input_value (ModelOutput): The input value. + + Returns: + ModelOutput: The transformed ModelOutput. 
+ """ + full_text = "" + async for model_output in input_value: + # Now model_output.text if full text, if it is a delta text, we should merge all delta text to a full text + full_text = model_output.text + yield model_output + # Get the storage conversation from share data + storage_conv: StorageConversation = await self.get_storage_conversation() + storage_conv.add_ai_message(full_text) + + +class ConversationMapperOperator( + BaseConversationOperator, MapOperator[ModelRequest, ModelRequest] +): + def __init__(self, **kwargs): + MapOperator.__init__(self, **kwargs) + + async def map(self, input_value: ModelRequest) -> ModelRequest: + """Map the input value to a ModelRequest. + + Args: + input_value (ModelRequest): The input value. + + Returns: + ModelRequest: The mapped ModelRequest. + """ + input_value = input_value.copy() + messages: List[ModelMessage] = await self.map_messages(input_value.messages) + # Overwrite the input value + input_value.messages = messages + return input_value + + async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]: + """Map the input messages to a list of ModelMessage. + + Args: + messages (List[ModelMessage]): The input messages. + + Returns: + List[ModelMessage]: The mapped ModelMessage. + """ + return messages + + def _split_messages_by_round( + self, messages: List[ModelMessage] + ) -> List[List[ModelMessage]]: + """Split the messages by round index. + + Args: + messages (List[ModelMessage]): The input messages. + + Returns: + List[List[ModelMessage]]: The splitted messages. 
+ """ + messages_by_round: List[List[ModelMessage]] = [] + last_round_index = 0 + for message in messages: + if not message.round_index: + # Round index must bigger than 0 + raise ValueError("Message round_index is not set") + if message.round_index > last_round_index: + last_round_index = message.round_index + messages_by_round.append([]) + messages_by_round[-1].append(message) + return messages_by_round + + +class BufferedConversationMapperOperator(ConversationMapperOperator): + """The buffered conversation mapper operator. + + This Operator must be used after the PreConversationOperator, + and it will map the messages in the storage conversation. + + Examples: + + Transform no history messages + + .. code-block:: python + + import asyncio + from dbgpt.core import ModelMessage + from dbgpt.core.operator import BufferedConversationMapperOperator + + # No history + messages = [ModelMessage(role="human", content="Hello", round_index=1)] + operator = BufferedConversationMapperOperator(last_k_round=1) + messages = asyncio.run(operator.map_messages(messages)) + assert messages == [ModelMessage(role="human", content="Hello", round_index=1)] + + Transform with history messages + + .. 
code-block:: python + + # With history + messages = [ + ModelMessage(role="human", content="Hi", round_index=1), + ModelMessage(role="ai", content="Hello!", round_index=1), + ModelMessage(role="system", content="Error 404", round_index=2), + ModelMessage(role="human", content="What's the error?", round_index=2), + ModelMessage(role="ai", content="Just a joke.", round_index=2), + ModelMessage(role="human", content="Funny!", round_index=3), + ] + operator = BufferedConversationMapperOperator(last_k_round=1) + messages = asyncio.run(operator.map_messages(messages)) + # Just keep the last one round, so the first round messages will be removed + # Note: The round index 3 is not a complete round + assert messages == [ + ModelMessage(role="system", content="Error 404", round_index=2), + ModelMessage(role="human", content="What's the error?", round_index=2), + ModelMessage(role="ai", content="Just a joke.", round_index=2), + ModelMessage(role="human", content="Funny!", round_index=3), + ] + """ + + def __init__(self, last_k_round: Optional[int] = 2, **kwargs): + super().__init__(**kwargs) + self._last_k_round = last_k_round + + async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]: + """Map the input messages to a list of ModelMessage. + + Args: + messages (List[ModelMessage]): The input messages. + + Returns: + List[ModelMessage]: The mapped ModelMessage. 
+ """ + messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round( + messages + ) + # Get the last k round messages + index = self._last_k_round + 1 + messages_by_round = messages_by_round[-index:] + messages: List[ModelMessage] = sum(messages_by_round, []) + return messages diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index 3d2b67027..1fdac4510 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -1,13 +1,13 @@ from __future__ import annotations import json -from abc import ABC import logging +from abc import ABC from dataclasses import asdict from typing import Any, Dict, TypeVar, Union -from dbgpt.core.awel import MapOperator from dbgpt.core import ModelOutput +from dbgpt.core.awel import MapOperator T = TypeVar("T") ResponseTye = Union[str, bytes, ModelOutput] diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index cc4ae66a3..e55869953 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -1,12 +1,20 @@ +import dataclasses import json -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional -from dbgpt._private.pydantic import BaseModel -from dbgpt.util.formatting import formatter, no_strict_formatter +from dbgpt._private.pydantic import BaseModel +from dbgpt.core._private.example_base import ExampleSelector from dbgpt.core.awel import MapOperator from dbgpt.core.interface.output_parser import BaseOutputParser -from dbgpt.core._private.example_base import ExampleSelector +from dbgpt.core.interface.storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageInterface, + StorageItem, +) +from dbgpt.util.formatting import formatter, no_strict_formatter def _jinja2_formatter(template: str, **kwargs: Any) -> str: @@ -86,6 +94,434 @@ def from_template(template: str) -> "PromptTemplateOperator": ) +@dataclasses.dataclass +class 
PromptTemplateIdentifier(ResourceIdentifier): + identifier_split: str = dataclasses.field(default="___$$$$___", init=False) + prompt_name: str + prompt_language: Optional[str] = None + sys_code: Optional[str] = None + model: Optional[str] = None + + def __post_init__(self): + if self.prompt_name is None: + raise ValueError("prompt_name cannot be None") + + if any( + self.identifier_split in key + for key in [ + self.prompt_name, + self.prompt_language, + self.sys_code, + self.model, + ] + if key is not None + ): + raise ValueError( + f"identifier_split {self.identifier_split} is not allowed in prompt_name, prompt_language, sys_code, model" + ) + + @property + def str_identifier(self) -> str: + return self.identifier_split.join( + key + for key in [ + self.prompt_name, + self.prompt_language, + self.sys_code, + self.model, + ] + if key is not None + ) + + def to_dict(self) -> Dict: + return { + "prompt_name": self.prompt_name, + "prompt_language": self.prompt_language, + "sys_code": self.sys_code, + "model": self.model, + } + + +@dataclasses.dataclass +class StoragePromptTemplate(StorageItem): + prompt_name: str + content: Optional[str] = None + prompt_language: Optional[str] = None + prompt_format: Optional[str] = None + input_variables: Optional[str] = None + model: Optional[str] = None + chat_scene: Optional[str] = None + sub_chat_scene: Optional[str] = None + prompt_type: Optional[str] = None + user_name: Optional[str] = None + sys_code: Optional[str] = None + _identifier: PromptTemplateIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + self._identifier = PromptTemplateIdentifier( + prompt_name=self.prompt_name, + prompt_language=self.prompt_language, + sys_code=self.sys_code, + model=self.model, + ) + self._check() # Assuming _check() is a method you need to call after initialization + + def to_prompt_template(self) -> PromptTemplate: + """Convert the storage prompt template to a prompt template.""" + input_variables = ( + None + if not 
self.input_variables + else self.input_variables.strip().split(",") + ) + return PromptTemplate( + input_variables=input_variables, + template=self.content, + template_scene=self.chat_scene, + prompt_name=self.prompt_name, + template_format=self.prompt_format, + ) + + @staticmethod + def from_prompt_template( + prompt_template: PromptTemplate, + prompt_name: str, + prompt_language: Optional[str] = None, + prompt_type: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + sub_chat_scene: Optional[str] = None, + model: Optional[str] = None, + **kwargs, + ) -> "StoragePromptTemplate": + """Convert a prompt template to a storage prompt template. + + Args: + prompt_template (PromptTemplate): The prompt template to convert from. + prompt_name (str): The name of the prompt. + prompt_language (Optional[str], optional): The language of the prompt. Defaults to None. e.g. zh-cn, en. + prompt_type (Optional[str], optional): The type of the prompt. Defaults to None. e.g. common, private. + sys_code (Optional[str], optional): The system code of the prompt. Defaults to None. + user_name (Optional[str], optional): The username of the prompt. Defaults to None. + sub_chat_scene (Optional[str], optional): The sub chat scene of the prompt. Defaults to None. + model (Optional[str], optional): The model name of the prompt. Defaults to None. + kwargs (Dict): Other params to build the storage prompt template. 
+ """ + input_variables = prompt_template.input_variables or kwargs.get( + "input_variables" + ) + if input_variables and isinstance(input_variables, list): + input_variables = ",".join(input_variables) + return StoragePromptTemplate( + prompt_name=prompt_name, + sys_code=sys_code, + user_name=user_name, + input_variables=input_variables, + model=model, + content=prompt_template.template or kwargs.get("content"), + prompt_language=prompt_language, + prompt_format=prompt_template.template_format + or kwargs.get("prompt_format"), + chat_scene=prompt_template.template_scene or kwargs.get("chat_scene"), + sub_chat_scene=sub_chat_scene, + prompt_type=prompt_type, + ) + + @property + def identifier(self) -> PromptTemplateIdentifier: + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge the other item into the current item. + + Args: + other (StorageItem): The other item to merge + """ + if not isinstance(other, StoragePromptTemplate): + raise ValueError( + f"Cannot merge {type(other)} into {type(self)} because they are not the same type." + ) + self.from_object(other) + + def to_dict(self) -> Dict: + return { + "prompt_name": self.prompt_name, + "content": self.content, + "prompt_language": self.prompt_language, + "prompt_format": self.prompt_format, + "input_variables": self.input_variables, + "model": self.model, + "chat_scene": self.chat_scene, + "sub_chat_scene": self.sub_chat_scene, + "prompt_type": self.prompt_type, + "user_name": self.user_name, + "sys_code": self.sys_code, + } + + def _check(self): + if self.prompt_name is None: + raise ValueError("prompt_name cannot be None") + if self.content is None: + raise ValueError("content cannot be None") + + def from_object(self, template: "StoragePromptTemplate") -> None: + """Load the prompt template from an existing prompt template object. + + Args: + template (PromptTemplate): The prompt template to load from. 
+ """ + self.content = template.content + self.prompt_format = template.prompt_format + self.input_variables = template.input_variables + self.model = template.model + self.chat_scene = template.chat_scene + self.sub_chat_scene = template.sub_chat_scene + self.prompt_type = template.prompt_type + self.user_name = template.user_name + + +class PromptManager: + """The manager class for prompt templates. + + Simple wrapper for the storage interface. + + Examples: + + .. code-block:: python + + # Default use InMemoryStorage + prompt_manager = PromptManager() + prompt_template = PromptTemplate( + template="hello {input}", + input_variables=["input"], + template_scene="chat_normal", + ) + prompt_manager.save(prompt_template, prompt_name="hello") + prompt_template_list = prompt_manager.list() + prompt_template_list = prompt_manager.prefer_query("hello") + + With a custom storage interface. + + .. code-block:: python + + from dbgpt.core.interface.storage import InMemoryStorage + + prompt_manager = PromptManager(InMemoryStorage()) + prompt_template = PromptTemplate( + template="hello {input}", + input_variables=["input"], + template_scene="chat_normal", + ) + prompt_manager.save(prompt_template, prompt_name="hello") + prompt_template_list = prompt_manager.list() + prompt_template_list = prompt_manager.prefer_query("hello") + + + """ + + def __init__( + self, storage: Optional[StorageInterface[StoragePromptTemplate, Any]] = None + ): + if storage is None: + storage = InMemoryStorage() + self._storage = storage + + @property + def storage(self) -> StorageInterface[StoragePromptTemplate, Any]: + """The storage interface for prompt templates.""" + return self._storage + + def prefer_query( + self, + prompt_name: str, + sys_code: Optional[str] = None, + prefer_prompt_language: Optional[str] = None, + prefer_model: Optional[str] = None, + **kwargs, + ) -> List[StoragePromptTemplate]: + """Query prompt templates from storage with prefer params. 
+ + Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model). + This method will query prompt templates with prefer params first, if not found, will query all prompt templates. + + Examples: + + Query a prompt template. + .. code-block:: python + + prompt_template_list = prompt_manager.prefer_query("hello") + + Query with sys_code and username. + + .. code-block:: python + + prompt_template_list = prompt_manager.prefer_query( + "hello", sys_code="sys_code", user_name="user_name" + ) + + Query with prefer prompt language. + + .. code-block:: python + + # First query with prompt name "hello" exactly. + # Second filter with prompt language "zh-cn", if not found, will return all prompt templates. + prompt_template_list = prompt_manager.prefer_query( + "hello", prefer_prompt_language="zh-cn" + ) + + Query with prefer model. + + .. code-block:: python + + # First query with prompt name "hello" exactly. + # Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates. + prompt_template_list = prompt_manager.prefer_query( + "hello", prefer_model="vicuna-13b-v1.5" + ) + + Args: + prompt_name (str): The name of the prompt template. + sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None. + prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None. + prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None. + kwargs (Dict): Other query params(If some key and value not None, wo we query it exactly). 
+ """ + query_spec = QuerySpec( + conditions={ + "prompt_name": prompt_name, + "sys_code": sys_code, + **kwargs, + } + ) + queries: List[StoragePromptTemplate] = self.storage.query( + query_spec, StoragePromptTemplate + ) + if not queries: + return [] + if prefer_prompt_language: + prefer_prompt_language = prefer_prompt_language.lower() + temp_queries = [ + query + for query in queries + if query.prompt_language + and query.prompt_language.lower() == prefer_prompt_language + ] + if temp_queries: + queries = temp_queries + if prefer_model: + prefer_model = prefer_model.lower() + temp_queries = [ + query + for query in queries + if query.model and query.model.lower() == prefer_model + ] + if temp_queries: + queries = temp_queries + return queries + + def save(self, prompt_template: PromptTemplate, prompt_name: str, **kwargs) -> None: + """Save a prompt template to storage. + + Examples: + + .. code-block:: python + + prompt_template = PromptTemplate( + template="hello {input}", + input_variables=["input"], + template_scene="chat_normal", + prompt_name="hello", + ) + prompt_manager.save(prompt_template) + + Save with sys_code and username. + + .. code-block:: python + + prompt_template = PromptTemplate( + template="hello {input}", + input_variables=["input"], + template_scene="chat_normal", + prompt_name="hello", + ) + prompt_manager.save( + prompt_template, sys_code="sys_code", user_name="user_name" + ) + + Args: + prompt_template (PromptTemplate): The prompt template to save. + prompt_name (str): The name of the prompt template. + kwargs (Dict): Other params to build the storage prompt template. + More details in :meth:`~StoragePromptTemplate.from_prompt_template`. + """ + storage_prompt_template = StoragePromptTemplate.from_prompt_template( + prompt_template, prompt_name, **kwargs + ) + self.storage.save(storage_prompt_template) + + def list(self, **kwargs) -> List[StoragePromptTemplate]: + """List prompt templates from storage. 
+ + Examples: + + List all prompt templates. + .. code-block:: python + + all_prompt_templates = prompt_manager.list() + + List with sys_code and username. + + .. code-block:: python + + templates = prompt_manager.list( + sys_code="sys_code", user_name="user_name" + ) + + Args: + kwargs (Dict): Other query params. + """ + query_spec = QuerySpec(conditions=kwargs) + return self.storage.query(query_spec, StoragePromptTemplate) + + def delete( + self, + prompt_name: str, + prompt_language: Optional[str] = None, + sys_code: Optional[str] = None, + model: Optional[str] = None, + ) -> None: + """Delete a prompt template from storage. + + Examples: + + Delete a prompt template. + + .. code-block:: python + + prompt_manager.delete("hello") + + Delete with sys_code and username. + + .. code-block:: python + + prompt_manager.delete( + "hello", sys_code="sys_code", user_name="user_name" + ) + + Args: + prompt_name (str): The name of the prompt template. + prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None. + sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None. + model (Optional[str], optional): The model of the prompt template. Defaults to None. 
+ """ + identifier = PromptTemplateIdentifier( + prompt_name=prompt_name, + prompt_language=prompt_language, + sys_code=sys_code, + model=model, + ) + self.storage.delete(identifier) + + class PromptTemplateOperator(MapOperator[Dict, str]): def __init__(self, prompt_template: PromptTemplate, **kwargs: Any): super().__init__(**kwargs) diff --git a/dbgpt/core/interface/retriever.py b/dbgpt/core/interface/retriever.py index ec6e6650d..385295534 100644 --- a/dbgpt/core/interface/retriever.py +++ b/dbgpt/core/interface/retriever.py @@ -1,4 +1,5 @@ from abc import abstractmethod + from dbgpt.core.awel import MapOperator from dbgpt.core.awel.task.base import IN, OUT diff --git a/dbgpt/core/interface/serialization.py b/dbgpt/core/interface/serialization.py index e26ec5735..b1d8d60eb 100644 --- a/dbgpt/core/interface/serialization.py +++ b/dbgpt/core/interface/serialization.py @@ -1,6 +1,7 @@ from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Type, Dict +from typing import Dict, Type class Serializable(ABC): diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py index f8e722f59..d05258e9b 100644 --- a/dbgpt/core/interface/storage.py +++ b/dbgpt/core/interface/storage.py @@ -1,10 +1,10 @@ -from typing import Generic, TypeVar, Type, Optional, Dict, Any, List from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar from dbgpt.core.interface.serialization import Serializable, Serializer -from dbgpt.util.serialization.json_serialization import JsonSerializer from dbgpt.util.annotations import PublicAPI from dbgpt.util.pagination_utils import PaginationResult +from dbgpt.util.serialization.json_serialization import JsonSerializer @PublicAPI(stability="beta") diff --git a/dbgpt/core/interface/tests/test_message.py b/dbgpt/core/interface/tests/test_message.py index 0650b1f67..98974e4bf 100755 --- a/dbgpt/core/interface/tests/test_message.py +++ 
b/dbgpt/core/interface/tests/test_message.py @@ -1,7 +1,7 @@ import pytest -from dbgpt.core.interface.tests.conftest import in_memory_storage from dbgpt.core.interface.message import * +from dbgpt.core.interface.tests.conftest import in_memory_storage @pytest.fixture diff --git a/dbgpt/core/interface/tests/test_prompt.py b/dbgpt/core/interface/tests/test_prompt.py new file mode 100644 index 000000000..e1a449013 --- /dev/null +++ b/dbgpt/core/interface/tests/test_prompt.py @@ -0,0 +1,320 @@ +import json + +import pytest + +from dbgpt.core.interface.prompt import ( + PromptManager, + PromptTemplate, + StoragePromptTemplate, +) +from dbgpt.core.interface.storage import QuerySpec +from dbgpt.core.interface.tests.conftest import in_memory_storage + + +@pytest.fixture +def sample_storage_prompt_template(): + return StoragePromptTemplate( + prompt_name="test_prompt", + content="Sample content, {var1}, {var2}", + prompt_language="en", + prompt_format="f-string", + input_variables="var1,var2", + model="model1", + chat_scene="scene1", + sub_chat_scene="subscene1", + prompt_type="type1", + user_name="user1", + sys_code="code1", + ) + + +@pytest.fixture +def complex_storage_prompt_template(): + content = """Database name: {db_name} Table structure definition: {table_info} User Question:{user_input}""" + return StoragePromptTemplate( + prompt_name="chat_data_auto_execute_prompt", + content=content, + prompt_language="en", + prompt_format="f-string", + input_variables="db_name,table_info,user_input", + model="vicuna-13b-v1.5", + chat_scene="chat_data", + sub_chat_scene="subscene1", + prompt_type="common", + user_name="zhangsan", + sys_code="dbgpt", + ) + + +@pytest.fixture +def prompt_manager(in_memory_storage): + return PromptManager(storage=in_memory_storage) + + +class TestPromptTemplate: + @pytest.mark.parametrize( + "template_str, input_vars, expected_output", + [ + ("Hello {name}", {"name": "World"}, "Hello World"), + ("{greeting}, {name}", {"greeting": "Hi", "name": 
"Alice"}, "Hi, Alice"), + ], + ) + def test_format_f_string(self, template_str, input_vars, expected_output): + prompt = PromptTemplate( + template=template_str, + input_variables=list(input_vars.keys()), + template_format="f-string", + ) + formatted_output = prompt.format(**input_vars) + assert formatted_output == expected_output + + @pytest.mark.parametrize( + "template_str, input_vars, expected_output", + [ + ("Hello {{ name }}", {"name": "World"}, "Hello World"), + ( + "{{ greeting }}, {{ name }}", + {"greeting": "Hi", "name": "Alice"}, + "Hi, Alice", + ), + ], + ) + def test_format_jinja2(self, template_str, input_vars, expected_output): + prompt = PromptTemplate( + template=template_str, + input_variables=list(input_vars.keys()), + template_format="jinja2", + ) + formatted_output = prompt.format(**input_vars) + assert formatted_output == expected_output + + def test_format_with_response_format(self): + template_str = "Response: {response}" + prompt = PromptTemplate( + template=template_str, + input_variables=["response"], + template_format="f-string", + response_format=json.dumps({"message": "hello"}), + ) + formatted_output = prompt.format(response="hello") + assert "Response: " in formatted_output + + def test_from_template(self): + template_str = "Hello {name}" + prompt = PromptTemplate.from_template(template_str) + assert prompt._prompt_template.template == template_str + assert prompt._prompt_template.input_variables == [] + + def test_format_missing_variable(self): + template_str = "Hello {name}" + prompt = PromptTemplate( + template=template_str, input_variables=["name"], template_format="f-string" + ) + with pytest.raises(KeyError): + prompt.format() + + def test_format_extra_variable(self): + template_str = "Hello {name}" + prompt = PromptTemplate( + template=template_str, + input_variables=["name"], + template_format="f-string", + template_is_strict=False, + ) + formatted_output = prompt.format(name="World", extra="unused") + assert formatted_output 
== "Hello World" + + def test_format_complex(self, complex_storage_prompt_template): + prompt = complex_storage_prompt_template.to_prompt_template() + formatted_output = prompt.format( + db_name="db1", + table_info="create table users(id int, name varchar(20))", + user_input="find all users whose name is 'Alice'", + ) + assert ( + formatted_output + == "Database name: db1 Table structure definition: create table users(id int, name varchar(20)) " + "User Question:find all users whose name is 'Alice'" + ) + + +class TestStoragePromptTemplate: + def test_constructor_and_properties(self): + storage_item = StoragePromptTemplate( + prompt_name="test", + content="Hello {name}", + prompt_language="en", + prompt_format="f-string", + input_variables="name", + model="model1", + chat_scene="chat", + sub_chat_scene="sub_chat", + prompt_type="type", + user_name="user", + sys_code="sys", + ) + assert storage_item.prompt_name == "test" + assert storage_item.content == "Hello {name}" + assert storage_item.prompt_language == "en" + assert storage_item.prompt_format == "f-string" + assert storage_item.input_variables == "name" + assert storage_item.model == "model1" + + def test_constructor_exceptions(self): + with pytest.raises(ValueError): + StoragePromptTemplate(prompt_name=None, content="Hello") + + def test_to_prompt_template(self, sample_storage_prompt_template): + prompt_template = sample_storage_prompt_template.to_prompt_template() + assert isinstance(prompt_template, PromptTemplate) + assert prompt_template.template == "Sample content, {var1}, {var2}" + assert prompt_template.input_variables == ["var1", "var2"] + + def test_from_prompt_template(self): + prompt_template = PromptTemplate( + template="Sample content, {var1}, {var2}", + input_variables=["var1", "var2"], + template_format="f-string", + ) + storage_prompt_template = StoragePromptTemplate.from_prompt_template( + prompt_template=prompt_template, prompt_name="test_prompt" + ) + assert 
storage_prompt_template.prompt_name == "test_prompt" + assert storage_prompt_template.content == "Sample content, {var1}, {var2}" + assert storage_prompt_template.input_variables == "var1,var2" + + def test_merge(self, sample_storage_prompt_template): + other = StoragePromptTemplate( + prompt_name="other_prompt", + content="Other content", + ) + sample_storage_prompt_template.merge(other) + assert sample_storage_prompt_template.content == "Other content" + + def test_to_dict(self, sample_storage_prompt_template): + result = sample_storage_prompt_template.to_dict() + assert result == { + "prompt_name": "test_prompt", + "content": "Sample content, {var1}, {var2}", + "prompt_language": "en", + "prompt_format": "f-string", + "input_variables": "var1,var2", + "model": "model1", + "chat_scene": "scene1", + "sub_chat_scene": "subscene1", + "prompt_type": "type1", + "user_name": "user1", + "sys_code": "code1", + } + + def test_save_and_load_storage( + self, sample_storage_prompt_template, in_memory_storage + ): + in_memory_storage.save(sample_storage_prompt_template) + loaded_item = in_memory_storage.load( + sample_storage_prompt_template.identifier, StoragePromptTemplate + ) + assert loaded_item.content == "Sample content, {var1}, {var2}" + + def test_check_exceptions(self): + with pytest.raises(ValueError): + StoragePromptTemplate(prompt_name=None, content="Hello") + + def test_from_object(self, sample_storage_prompt_template): + other = StoragePromptTemplate(prompt_name="other", content="Other content") + sample_storage_prompt_template.from_object(other) + assert sample_storage_prompt_template.content == "Other content" + assert sample_storage_prompt_template.input_variables != "var1,var2" + # Prompt name should not be changed + assert sample_storage_prompt_template.prompt_name == "test_prompt" + assert sample_storage_prompt_template.sys_code == "code1" + + +class TestPromptManager: + def test_save(self, prompt_manager, in_memory_storage): + prompt_template = 
PromptTemplate( + template="hello {input}", + input_variables=["input"], + template_scene="chat_normal", + ) + prompt_manager.save( + prompt_template, + prompt_name="hello", + ) + result = in_memory_storage.query( + QuerySpec(conditions={"prompt_name": "hello"}), StoragePromptTemplate + ) + assert len(result) == 1 + assert result[0].content == "hello {input}" + + def test_prefer_query_simple(self, prompt_manager, in_memory_storage): + in_memory_storage.save( + StoragePromptTemplate(prompt_name="test_prompt", content="test") + ) + result = prompt_manager.prefer_query("test_prompt") + assert len(result) == 1 + assert result[0].content == "test" + + def test_prefer_query_language(self, prompt_manager, in_memory_storage): + for language in ["en", "zh"]: + in_memory_storage.save( + StoragePromptTemplate( + prompt_name="test_prompt", + content="test", + prompt_language=language, + ) + ) + # Prefer zh, and zh exists, will return zh prompt template + result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh") + assert len(result) == 1 + assert result[0].content == "test" + assert result[0].prompt_language == "zh" + # Prefer language not exists, will return all prompt templates of this name + result = prompt_manager.prefer_query( + "test_prompt", prefer_prompt_language="not_exist" + ) + assert len(result) == 2 + + def test_prefer_query_model(self, prompt_manager, in_memory_storage): + for model in ["model1", "model2"]: + in_memory_storage.save( + StoragePromptTemplate( + prompt_name="test_prompt", content="test", model=model + ) + ) + # Prefer model1, and model1 exists, will return model1 prompt template + result = prompt_manager.prefer_query("test_prompt", prefer_model="model1") + assert len(result) == 1 + assert result[0].content == "test" + assert result[0].model == "model1" + # Prefer model not exists, will return all prompt templates of this name + result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist") + assert len(result) == 
2 + + def test_list(self, prompt_manager, in_memory_storage): + prompt_manager.save( + PromptTemplate(template="Hello {name}", input_variables=["name"]), + prompt_name="name1", + ) + prompt_manager.save( + PromptTemplate( + template="Write a SQL of {dialect} to query all data of {table_name}.", + input_variables=["dialect", "table_name"], + ), + prompt_name="sql_template", + ) + all_templates = prompt_manager.list() + assert len(all_templates) == 2 + assert len(prompt_manager.list(prompt_name="name1")) == 1 + assert len(prompt_manager.list(prompt_name="not exist")) == 0 + + def test_delete(self, prompt_manager, in_memory_storage): + prompt_manager.save( + PromptTemplate(template="Hello {name}", input_variables=["name"]), + prompt_name="to_delete", + ) + prompt_manager.delete("to_delete") + result = in_memory_storage.query( + QuerySpec(conditions={"prompt_name": "to_delete"}), StoragePromptTemplate + ) + assert len(result) == 0 diff --git a/dbgpt/core/interface/tests/test_storage.py b/dbgpt/core/interface/tests/test_storage.py index 74864e02f..827e0f028 100644 --- a/dbgpt/core/interface/tests/test_storage.py +++ b/dbgpt/core/interface/tests/test_storage.py @@ -1,10 +1,12 @@ -import pytest from typing import Dict, Type, Union + +import pytest + from dbgpt.core.interface.storage import ( + InMemoryStorage, + QuerySpec, ResourceIdentifier, StorageError, - QuerySpec, - InMemoryStorage, StorageItem, ) from dbgpt.util.serialization.json_serialization import JsonSerializer diff --git a/dbgpt/core/operator/__init__.py b/dbgpt/core/operator/__init__.py new file mode 100644 index 000000000..0d0287d61 --- /dev/null +++ b/dbgpt/core/operator/__init__.py @@ -0,0 +1,31 @@ +from dbgpt.core.interface.operator.llm_operator import ( + BaseLLM, + LLMBranchOperator, + LLMOperator, + RequestBuildOperator, + StreamingLLMOperator, +) +from dbgpt.core.interface.operator.message_operator import ( + BaseConversationOperator, + BufferedConversationMapperOperator, + ConversationMapperOperator, 
+ PostConversationOperator, + PostStreamingConversationOperator, + PreConversationOperator, +) +from dbgpt.core.interface.prompt import PromptTemplateOperator + +__ALL__ = [ + "BaseLLM", + "LLMBranchOperator", + "LLMOperator", + "RequestBuildOperator", + "StreamingLLMOperator", + "BaseConversationOperator", + "BufferedConversationMapperOperator", + "ConversationMapperOperator", + "PostConversationOperator", + "PostStreamingConversationOperator", + "PreConversationOperator", + "PromptTemplateOperator", +] diff --git a/dbgpt/model/__init__.py b/dbgpt/model/__init__.py index 4e317eb06..e13ec4adc 100644 --- a/dbgpt/model/__init__.py +++ b/dbgpt/model/__init__.py @@ -1,4 +1,13 @@ from dbgpt.model.cluster.client import DefaultLLMClient -from dbgpt.model.utils.chatgpt_utils import OpenAILLMClient +from dbgpt.model.utils.chatgpt_utils import ( + OpenAILLMClient, + OpenAIStreamingOperator, + MixinLLMOperator, +) -__ALL__ = ["DefaultLLMClient", "OpenAILLMClient"] +__ALL__ = [ + "DefaultLLMClient", + "OpenAILLMClient", + "OpenAIStreamingOperator", + "MixinLLMOperator", +] diff --git a/dbgpt/model/operator/model_operator.py b/dbgpt/model/operator/model_operator.py index fcb3cadae..9ac82d026 100644 --- a/dbgpt/model/operator/model_operator.py +++ b/dbgpt/model/operator/model_operator.py @@ -171,7 +171,7 @@ def __init__( self._model_task_name = model_task_name self._cache_task_name = cache_task_name - async def branchs(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]: + async def branches(self) -> Dict[BranchFunc[Dict], Union[BaseOperator, str]]: """Defines branch logic based on cache availability. 
Returns: @@ -233,7 +233,7 @@ async def transform_stream( outputs = [] async for out in input_value: if not llm_cache_key: - llm_cache_key = await self.current_dag_context.get_share_data( + llm_cache_key = await self.current_dag_context.get_from_share_data( _LLM_MODEL_INPUT_VALUE_KEY ) outputs.append(out) @@ -265,7 +265,7 @@ async def map(self, input_value: ModelOutput) -> ModelOutput: Returns: ModelOutput: The same input model output. """ - llm_cache_key: LLMCacheKey = await self.current_dag_context.get_share_data( + llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data( _LLM_MODEL_INPUT_VALUE_KEY ) llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value) diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index f3aea5a07..32d6e29e0 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -3,11 +3,27 @@ import os import logging from dataclasses import dataclass +from abc import ABC import importlib.metadata as metadata -from typing import List, Dict, Any, Optional, TYPE_CHECKING, Union, AsyncIterator - +from typing import ( + List, + Dict, + Any, + Optional, + TYPE_CHECKING, + Union, + AsyncIterator, + Callable, + Awaitable, +) + +from dbgpt.component import ComponentType +from dbgpt.core.operator import BaseLLM +from dbgpt.core.awel import TransformStreamAbsOperator, BaseOperator from dbgpt.core.interface.llm import ModelMetadata, LLMClient from dbgpt.core.interface.llm import ModelOutput, ModelRequest +from dbgpt.model.cluster.client import DefaultLLMClient +from dbgpt.model.cluster import WorkerManagerFactory if TYPE_CHECKING: import httpx @@ -176,13 +192,13 @@ async def generate_stream( self, request: ModelRequest ) -> AsyncIterator[ModelOutput]: messages = request.to_openai_messages() - payload = self._build_request(request) + payload = self._build_request(request, True) try: chat_completion = await self.client.chat.completions.create( 
messages=messages, **payload ) text = "" - for r in chat_completion: + async for r in chat_completion: if len(r.choices) == 0: continue if r.choices[0].delta.content is not None: @@ -221,17 +237,74 @@ async def count_token(self, model: str, prompt: str) -> int: raise NotImplementedError() +class OpenAIStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): + """Transform ModelOutput to openai stream format.""" + + async def transform_stream( + self, input_value: AsyncIterator[ModelOutput] + ) -> AsyncIterator[str]: + async def model_caller() -> str: + """Read model name from share data. + In streaming mode, this transform_stream function will be executed + before parent operator(Streaming Operator is trigger by downstream Operator). + """ + return await self.current_dag_context.get_from_share_data( + BaseLLM.SHARE_DATA_KEY_MODEL_NAME + ) + + async for output in _to_openai_stream(input_value, None, model_caller): + yield output + + +class MixinLLMOperator(BaseLLM, BaseOperator, ABC): + """Mixin class for LLM operator. + + This class extends BaseOperator by adding LLM capabilities. + """ + + def __init__(self, default_client: Optional[LLMClient] = None, **kwargs): + super().__init__(default_client) + self._default_llm_client = default_client + + @property + def llm_client(self) -> LLMClient: + if not self._llm_client: + worker_manager_factory: WorkerManagerFactory = ( + self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, + WorkerManagerFactory, + default_component=None, + ) + ) + if worker_manager_factory: + self._llm_client = DefaultLLMClient(worker_manager_factory.create()) + else: + if self._default_llm_client is None: + from dbgpt.model import OpenAILLMClient + + self._default_llm_client = OpenAILLMClient() + logger.info( + f"Can't find worker manager factory, use default llm client {self._default_llm_client}." 
+ ) + self._llm_client = self._default_llm_client + return self._llm_client + + async def _to_openai_stream( - model: str, output_iter: AsyncIterator[ModelOutput] + output_iter: AsyncIterator[ModelOutput], + model: Optional[str] = None, + model_caller: Callable[[], Union[Awaitable[str], str]] = None, ) -> AsyncIterator[str]: """Convert the output_iter to openai stream format. Args: - model (str): The model name. output_iter (AsyncIterator[ModelOutput]): The output iterator. + model (Optional[str], optional): The model name. Defaults to None. + model_caller (Callable[[None], Union[Awaitable[str], str]], optional): The model caller. Defaults to None. """ import json import shortuuid + import asyncio from fastchat.protocol.openai_api_protocol import ( ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, @@ -245,12 +318,19 @@ async def _to_openai_stream( delta=DeltaMessage(role="assistant"), finish_reason=None, ) - chunk = ChatCompletionStreamResponse(id=id, choices=[choice_data], model=model) + chunk = ChatCompletionStreamResponse( + id=id, choices=[choice_data], model=model or "" + ) yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" previous_text = "" finish_stream_events = [] async for model_output in output_iter: + if model_caller is not None: + if asyncio.iscoroutinefunction(model_caller): + model = await model_caller() + else: + model = model_caller() model_output: ModelOutput = model_output if model_output.error_code != 0: yield f"data: {json.dumps(model_output.to_dict(), ensure_ascii=False)}\n\n" diff --git a/dbgpt/serve/prompt/api/endpoints.py b/dbgpt/serve/prompt/api/endpoints.py index c0dbca3df..c043858fc 100644 --- a/dbgpt/serve/prompt/api/endpoints.py +++ b/dbgpt/serve/prompt/api/endpoints.py @@ -1,15 +1,16 @@ -from typing import Optional, List from functools import cache -from fastapi import APIRouter, Depends, Query, HTTPException -from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from typing 
import List, Optional +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import SystemApp from dbgpt.serve.core import Result from dbgpt.util import PaginationResult -from .schemas import ServeRequest, ServerResponse + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from ..config import APP_NAME, SERVE_APP_NAME, ServeConfig, SERVE_SERVICE_COMPONENT_NAME +from .schemas import ServeRequest, ServerResponse router = APIRouter() diff --git a/dbgpt/serve/prompt/api/schemas.py b/dbgpt/serve/prompt/api/schemas.py index e5c7610d4..7cdbc7a1b 100644 --- a/dbgpt/serve/prompt/api/schemas.py +++ b/dbgpt/serve/prompt/api/schemas.py @@ -1,6 +1,8 @@ # Define your Pydantic schemas here from typing import Optional + from dbgpt._private.pydantic import BaseModel, Field + from ..config import SERVE_APP_NAME_HUMP diff --git a/dbgpt/serve/prompt/config.py b/dbgpt/serve/prompt/config.py index c304eaf08..2713ce467 100644 --- a/dbgpt/serve/prompt/config.py +++ b/dbgpt/serve/prompt/config.py @@ -1,9 +1,8 @@ -from typing import Optional from dataclasses import dataclass, field +from typing import Optional from dbgpt.serve.core import BaseServeConfig - APP_NAME = "prompt" SERVE_APP_NAME = "dbgpt_serve_prompt" SERVE_APP_NAME_HUMP = "dbgpt_serve_Prompt" diff --git a/dbgpt/serve/prompt/models/models.py b/dbgpt/serve/prompt/models/models.py index 812ff8fae..8095a5014 100644 --- a/dbgpt/serve/prompt/models/models.py +++ b/dbgpt/serve/prompt/models/models.py @@ -1,33 +1,64 @@ """This is an auto-generated model file You can define your own models and DAOs here """ -from typing import Union, Any, Dict from datetime import datetime -from sqlalchemy import Column, Integer, String, Index, Text, DateTime, UniqueConstraint -from dbgpt.storage.metadata import Model, BaseDao, db +from typing import Any, Dict, Union + +from 
sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint + +from dbgpt.storage.metadata import BaseDao, Model, db + from ..api.schemas import ServeRequest, ServerResponse -from ..config import ServeConfig, SERVER_APP_TABLE_NAME +from ..config import SERVER_APP_TABLE_NAME, ServeConfig class ServeEntity(Model): __tablename__ = "prompt_manage" __table_args__ = ( - UniqueConstraint("prompt_name", "sys_code", name="uk_prompt_name_sys_code"), + UniqueConstraint( + "prompt_name", + "sys_code", + "prompt_language", + "model", + name="uk_prompt_name_sys_code", + ), ) id = Column(Integer, primary_key=True, comment="Auto increment id") - chat_scene = Column(String(100)) - sub_chat_scene = Column(String(100)) - prompt_type = Column(String(100)) - prompt_name = Column(String(512)) - content = Column(Text) - user_name = Column(String(128)) + chat_scene = Column(String(100), comment="Chat scene") + sub_chat_scene = Column(String(100), comment="Sub chat scene") + prompt_type = Column(String(100), comment="Prompt type(eg: common, private)") + prompt_name = Column(String(256), comment="Prompt name") + content = Column(Text, comment="Prompt content") + input_variables = Column( + String(1024), nullable=True, comment="Prompt input variables(split by comma))" + ) + model = Column( + String(128), + nullable=True, + comment="Prompt model name(we can use different models for different prompt", + ) + prompt_language = Column( + String(32), index=True, nullable=True, comment="Prompt language(eg:en, zh-cn)" + ) + prompt_format = Column( + String(32), + index=True, + nullable=True, + default="f-string", + comment="Prompt format(eg: f-string, jinja2)", + ) + user_name = Column(String(128), index=True, nullable=True, comment="User name") sys_code = Column(String(128), index=True, nullable=True, comment="System code") gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") gmt_modified = Column(DateTime, default=datetime.now, comment="Record 
update time") def __repr__(self): - return f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}',user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + return ( + f"ServeEntity(id={self.id}, chat_scene='{self.chat_scene}', sub_chat_scene='{self.sub_chat_scene}', " + f"prompt_type='{self.prompt_type}', prompt_name='{self.prompt_name}', content='{self.content}'," + f"user_name='{self.user_name}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + ) class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): diff --git a/dbgpt/serve/prompt/models/prompt_template_adapter.py b/dbgpt/serve/prompt/models/prompt_template_adapter.py new file mode 100644 index 000000000..59a7c5ba2 --- /dev/null +++ b/dbgpt/serve/prompt/models/prompt_template_adapter.py @@ -0,0 +1,56 @@ +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.prompt import PromptTemplateIdentifier, StoragePromptTemplate +from dbgpt.core.interface.storage import StorageItemAdapter + +from .models import ServeEntity + + +class PromptTemplateAdapter(StorageItemAdapter[StoragePromptTemplate, ServeEntity]): + def to_storage_format(self, item: StoragePromptTemplate) -> ServeEntity: + return ServeEntity( + chat_scene=item.chat_scene, + sub_chat_scene=item.sub_chat_scene, + prompt_type=item.prompt_type, + prompt_name=item.prompt_name, + content=item.content, + input_variables=item.input_variables, + model=item.model, + prompt_language=item.prompt_language, + prompt_format=item.prompt_format, + user_name=item.user_name, + sys_code=item.sys_code, + ) + + def from_storage_format(self, model: ServeEntity) -> StoragePromptTemplate: + return StoragePromptTemplate( + chat_scene=model.chat_scene, + sub_chat_scene=model.sub_chat_scene, + prompt_type=model.prompt_type, + 
prompt_name=model.prompt_name, + content=model.content, + input_variables=model.input_variables, + model=model.model, + prompt_language=model.prompt_language, + prompt_format=model.prompt_format, + user_name=model.user_name, + sys_code=model.sys_code, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ServeEntity], + resource_id: PromptTemplateIdentifier, + **kwargs, + ): + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + query_obj = session.query(ServeEntity) + for key, value in resource_id.to_dict().items(): + if value is None: + continue + query_obj = query_obj.filter(getattr(ServeEntity, key) == value) + return query_obj diff --git a/dbgpt/serve/prompt/serve.py b/dbgpt/serve/prompt/serve.py index 74db1a996..015450523 100644 --- a/dbgpt/serve/prompt/serve.py +++ b/dbgpt/serve/prompt/serve.py @@ -1,17 +1,80 @@ -from typing import List, Optional +import logging +from typing import List, Optional, Union + +from sqlalchemy import URL + from dbgpt.component import BaseComponent, SystemApp +from dbgpt.core import PromptManager -from .api.endpoints import router, init_endpoints +from ...storage.metadata import DatabaseManager +from .api.endpoints import init_endpoints, router from .config import ( + APP_NAME, SERVE_APP_NAME, SERVE_APP_NAME_HUMP, - APP_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig, ) +from .models.prompt_template_adapter import PromptTemplateAdapter + +logger = logging.getLogger(__name__) class Serve(BaseComponent): + """Serve component + + Examples: + + Register the serve component to the system app + + .. 
code-block:: python + + from fastapi import FastAPI + from dbgpt import SystemApp + from dbgpt.core import PromptTemplate + from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME + + app = FastAPI() + system_app = SystemApp(app) + system_app.register(Serve, api_prefix="/api/v1/prompt") + # Run before start hook + system_app.before_start() + + prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve) + + # Get the prompt manager + prompt_manager = prompt_serve.prompt_manager + prompt_manager.save( + PromptTemplate(template="Hello {name}", input_variables=["name"]), + prompt_name="prompt_name", + ) + + With your database url + + .. code-block:: python + + from fastapi import FastAPI + from dbgpt import SystemApp + from dbgpt.core import PromptTemplate + from dbgpt.serve.prompt.serve import Serve, SERVE_APP_NAME + + app = FastAPI() + system_app = SystemApp(app) + system_app.register(Serve, api_prefix="/api/v1/prompt", db_url_or_db="sqlite:///:memory:", try_create_tables=True) + # Run before start hook + system_app.before_start() + + prompt_serve = system_app.get_component(SERVE_APP_NAME, Serve) + + # Get the prompt manager + prompt_manager = prompt_serve.prompt_manager + prompt_manager.save( + PromptTemplate(template="Hello {name}", input_variables=["name"]), + prompt_name="prompt_name", + ) + + """ + name = SERVE_APP_NAME def __init__( @@ -19,12 +82,17 @@ def __init__( system_app: SystemApp, api_prefix: Optional[str] = f"/api/v1/serve/{APP_NAME}", tags: Optional[List[str]] = None, + db_url_or_db: Union[str, URL, DatabaseManager] = None, + try_create_tables: Optional[bool] = False, ): if tags is None: tags = [SERVE_APP_NAME_HUMP] self._system_app = None self._api_prefix = api_prefix self._tags = tags + self._prompt_manager = None + self._db_url_or_db = db_url_or_db + self._try_create_tables = try_create_tables def init_app(self, system_app: SystemApp): self._system_app = system_app @@ -33,10 +101,37 @@ def init_app(self, system_app: SystemApp): ) 
init_endpoints(self._system_app) + @property + def prompt_manager(self) -> PromptManager: + """Get the prompt manager of the serve app with db storage""" + return self._prompt_manager + def before_start(self): """Called before the start of the application. You can do some initialization here. """ # import your own module here to ensure the module is loaded before the application starts + from dbgpt.core.interface.prompt import PromptManager + from dbgpt.storage.metadata import Model, db + from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer + from .models.models import ServeEntity + + init_db = self._db_url_or_db or db + init_db = DatabaseManager.build_from(init_db, base=Model) + if self._try_create_tables: + try: + init_db.create_all() + except Exception as e: + logger.warning(f"Failed to create tables: {e}") + storage_adapter = PromptTemplateAdapter() + serializer = JsonSerializer() + storage = SQLAlchemyStorage( + init_db, + ServeEntity, + storage_adapter, + serializer, + ) + self._prompt_manager = PromptManager(storage) diff --git a/dbgpt/serve/prompt/service/service.py b/dbgpt/serve/prompt/service/service.py index ce90f7f3c..037b8cc3c 100644 --- a/dbgpt/serve/prompt/service/service.py +++ b/dbgpt/serve/prompt/service/service.py @@ -1,11 +1,13 @@ -from typing import Optional, List +from typing import List, Optional + from dbgpt.component import BaseComponent, SystemApp +from dbgpt.serve.core import BaseService from dbgpt.storage.metadata import BaseDao from dbgpt.util.pagination_utils import PaginationResult -from dbgpt.serve.core import BaseService -from ..models.models import ServeDao, ServeEntity + from ..api.schemas import ServeRequest, ServerResponse -from ..config import SERVE_SERVICE_COMPONENT_NAME, SERVE_CONFIG_KEY_PREFIX, ServeConfig +from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..models.models import ServeDao, 
ServeEntity class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): diff --git a/dbgpt/serve/prompt/tests/test_endpoints.py b/dbgpt/serve/prompt/tests/test_endpoints.py index e45a56745..701b43523 100644 --- a/dbgpt/serve/prompt/tests/test_endpoints.py +++ b/dbgpt/serve/prompt/tests/test_endpoints.py @@ -1,15 +1,15 @@ import pytest +from fastapi import FastAPI from httpx import AsyncClient -from fastapi import FastAPI from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client from dbgpt.storage.metadata import db from dbgpt.util import PaginationResult -from ..config import SERVE_CONFIG_KEY_PREFIX -from ..api.endpoints import router, init_endpoints -from ..api.schemas import ServeRequest, ServerResponse -from dbgpt.serve.core.tests.conftest import client, asystem_app +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX @pytest.fixture(autouse=True) diff --git a/dbgpt/serve/prompt/tests/test_models.py b/dbgpt/serve/prompt/tests/test_models.py index c5ef89f8f..744412ec5 100644 --- a/dbgpt/serve/prompt/tests/test_models.py +++ b/dbgpt/serve/prompt/tests/test_models.py @@ -1,9 +1,12 @@ from typing import List + import pytest + from dbgpt.storage.metadata import db -from ..config import ServeConfig + from ..api.schemas import ServeRequest, ServerResponse -from ..models.models import ServeEntity, ServeDao +from ..config import ServeConfig +from ..models.models import ServeDao, ServeEntity @pytest.fixture(autouse=True) @@ -34,6 +37,8 @@ def default_entity_dict(): "content": "Write a qsort function in python.", "user_name": "zhangsan", "sys_code": "dbgpt", + "prompt_language": "zh", + "model": "vicuna-13b-v1.5", } @@ -60,7 +65,14 @@ def test_entity_create(default_entity_dict): def test_entity_unique_key(default_entity_dict): ServeEntity.create(**default_entity_dict) with pytest.raises(Exception): - 
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}) + ServeEntity.create( + **{ + "prompt_name": "my_prompt_1", + "sys_code": "dbgpt", + "prompt_language": "zh", + "model": "vicuna-13b-v1.5", + } + ) def test_entity_get(default_entity_dict): diff --git a/dbgpt/serve/prompt/tests/test_prompt_template_adapter.py b/dbgpt/serve/prompt/tests/test_prompt_template_adapter.py new file mode 100644 index 000000000..7d9911984 --- /dev/null +++ b/dbgpt/serve/prompt/tests/test_prompt_template_adapter.py @@ -0,0 +1,144 @@ +import pytest + +from dbgpt.core.interface.prompt import PromptManager, PromptTemplate +from dbgpt.storage.metadata import db +from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage +from dbgpt.util.serialization.json_serialization import JsonSerializer + +from ..models.prompt_template_adapter import PromptTemplateAdapter, ServeEntity + + +@pytest.fixture +def serializer(): + return JsonSerializer() + + +@pytest.fixture +def db_url(): + """Use in-memory SQLite database for testing""" + return "sqlite:///:memory:" + + +@pytest.fixture +def db_manager(db_url): + db.init_db(db_url) + db.create_all() + return db + + +@pytest.fixture +def storage_adapter(): + return PromptTemplateAdapter() + + +@pytest.fixture +def storage(db_manager, serializer, storage_adapter): + storage = SQLAlchemyStorage( + db_manager, + ServeEntity, + storage_adapter, + serializer, + ) + return storage + + +@pytest.fixture +def prompt_manager(storage): + return PromptManager(storage) + + +def test_save(prompt_manager: PromptManager): + prompt_template = PromptTemplate( + template="hello {input}", + input_variables=["input"], + template_scene="chat_normal", + ) + prompt_manager.save( + prompt_template, + prompt_name="hello", + ) + + with db.session() as session: + # Query from database + result = ( + session.query(ServeEntity).filter(ServeEntity.prompt_name == "hello").all() + ) + assert len(result) == 1 + assert result[0].prompt_name == "hello" + assert 
result[0].content == "hello {input}" + assert result[0].input_variables == "input" + with db.session() as session: + assert session.query(ServeEntity).count() == 1 + assert ( + session.query(ServeEntity) + .filter(ServeEntity.prompt_name == "not exist prompt name") + .count() + == 0 + ) + + +def test_prefer_query_language(prompt_manager: PromptManager): + for language in ["en", "zh"]: + prompt_template = PromptTemplate( + template="test", + input_variables=[], + template_scene="chat_normal", + ) + prompt_manager.save( + prompt_template, + prompt_name="test_prompt", + prompt_language=language, + ) + # Prefer zh, and zh exists, will return zh prompt template + result = prompt_manager.prefer_query("test_prompt", prefer_prompt_language="zh") + assert len(result) == 1 + assert result[0].content == "test" + assert result[0].prompt_language == "zh" + # Prefer language not exists, will return all prompt templates of this name + result = prompt_manager.prefer_query( + "test_prompt", prefer_prompt_language="not_exist" + ) + assert len(result) == 2 + + +def test_prefer_query_model(prompt_manager: PromptManager): + for model in ["model1", "model2"]: + prompt_template = PromptTemplate( + template="test", + input_variables=[], + template_scene="chat_normal", + ) + prompt_manager.save( + prompt_template, + prompt_name="test_prompt", + model=model, + ) + # Prefer model1, and model1 exists, will return model1 prompt template + result = prompt_manager.prefer_query("test_prompt", prefer_model="model1") + assert len(result) == 1 + assert result[0].content == "test" + assert result[0].model == "model1" + # Prefer model not exists, will return all prompt templates of this name + result = prompt_manager.prefer_query("test_prompt", prefer_model="not_exist") + assert len(result) == 2 + + +def test_list(prompt_manager: PromptManager): + for i in range(10): + prompt_template = PromptTemplate( + template="test", + input_variables=[], + template_scene="chat_normal", + ) + prompt_manager.save( 
+ prompt_template, + prompt_name=f"test_prompt_{i}", + sys_code="dbgpt" if i % 2 == 0 else "not_dbgpt", + ) + # Test list all + result = prompt_manager.list() + assert len(result) == 10 + + for i in range(10): + assert len(prompt_manager.list(prompt_name=f"test_prompt_{i}")) == 1 + assert len(prompt_manager.list(sys_code="dbgpt")) == 5 diff --git a/dbgpt/serve/prompt/tests/test_service.py b/dbgpt/serve/prompt/tests/test_service.py index a7c9cc686..3992e89cb 100644 --- a/dbgpt/serve/prompt/tests/test_service.py +++ b/dbgpt/serve/prompt/tests/test_service.py @@ -1,11 +1,13 @@ from typing import List + import pytest + from dbgpt.component import SystemApp -from dbgpt.storage.metadata import db from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db -from ..models.models import ServeEntity from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity from ..service.service import Service diff --git a/dbgpt/storage/metadata/db_manager.py b/dbgpt/storage/metadata/db_manager.py index 782fcf8ab..0876dd491 100644 --- a/dbgpt/storage/metadata/db_manager.py +++ b/dbgpt/storage/metadata/db_manager.py @@ -236,6 +236,7 @@ def init_db( engine_args: Optional[Dict] = None, base: Optional[DeclarativeMeta] = None, query_class=BaseQuery, + override_query_class: Optional[bool] = False, ): """Initialize the database manager. @@ -244,15 +245,16 @@ def init_db( engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. base (Optional[DeclarativeMeta]): The base class. Defaults to None. query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. + override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False. 
""" self._db_url = db_url if query_class is not None: self.Query = query_class if base is not None: self._base = base - if not hasattr(base, "query"): + if not hasattr(base, "query") or override_query_class: base.query = _QueryObject(self) - if not getattr(base, "query_class", None): + if not getattr(base, "query_class", None) or override_query_class: base.query_class = self.Query self._engine = create_engine(db_url, **(engine_args or {})) session_factory = sessionmaker(bind=self._engine) @@ -299,6 +301,59 @@ def init_default_db( def create_all(self): self.Model.metadata.create_all(self._engine) + @staticmethod + def build_from( + db_url_or_db: Union[str, URL, DatabaseManager], + engine_args: Optional[Dict] = None, + base: Optional[DeclarativeMeta] = None, + query_class=BaseQuery, + override_query_class: Optional[bool] = False, + ) -> DatabaseManager: + """Build the database manager from the db_url_or_db. + + Examples: + + Build from the database url. + + .. code-block:: python + + from dbgpt.storage.metadata import DatabaseManager + from sqlalchemy import Column, Integer, String + db = DatabaseManager.build_from("sqlite:///:memory:") + class User(db.Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + fullname = Column(String(50)) + db.create_all() + with db.session() as session: + session.add(User(name="test", fullname="test")) + session.commit() + print(User.query.filter(User.name == "test").all()) + + Args: + db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager. + engine_args (Optional[Dict], optional): The engine arguments. Defaults to None. + base (Optional[DeclarativeMeta]): The base class. Defaults to None. + query_class (BaseQuery, optional): The query class. Defaults to BaseQuery. + override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False. + + Returns: + DatabaseManager: The database manager. 
+ """ + if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL): + db_manager = DatabaseManager() + db_manager.init_db( + db_url_or_db, engine_args, base, query_class, override_query_class + ) + return db_manager + elif isinstance(db_url_or_db, DatabaseManager): + return db_url_or_db + else: + raise ValueError( + f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}" + ) + db = DatabaseManager() """The global database manager. @@ -375,14 +430,21 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]: class CRUDMixin(BaseCRUDMixin[T], Generic[T]): """Mixin that adds convenience methods for CRUD (create, read, update, delete)""" + _db_manager: DatabaseManager = db_manager + + @classmethod + def set_db_manager(cls, db_manager: DatabaseManager): + # TODO: It is hard to replace to user DB Connection + cls._db_manager = db_manager + @classmethod def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]: """Get a record by its primary key identifier.""" - return db_manager._session().get(cls, ident) + return cls._db_manager._session().get(cls, ident) def save(self: T, commit: Optional[bool] = True) -> T: """Save the record.""" - session = db_manager._session() + session = self._db_manager._session() session.add(self) if commit: session.commit() @@ -390,7 +452,7 @@ def save(self: T, commit: Optional[bool] = True) -> T: def delete(self: T, commit: Optional[bool] = True) -> None: """Remove the record from the database.""" - session = db_manager._session() + session = self._db_manager._session() session.delete(self) return commit and session.commit() diff --git a/dbgpt/storage/metadata/db_storage.py b/dbgpt/storage/metadata/db_storage.py index 2cc1a5118..daa87ceba 100644 --- a/dbgpt/storage/metadata/db_storage.py +++ b/dbgpt/storage/metadata/db_storage.py @@ -34,16 +34,9 @@ def __init__( query_class=BaseQuery, ): super().__init__(serializer=serializer, adapter=adapter) - if isinstance(db_url_or_db, str) or 
isinstance(db_url_or_db, URL): - db_manager = DatabaseManager() - db_manager.init_db(db_url_or_db, engine_args, base, query_class) - self.db_manager = db_manager - elif isinstance(db_url_or_db, DatabaseManager): - self.db_manager = db_url_or_db - else: - raise ValueError( - f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}" - ) + self.db_manager = DatabaseManager.build_from( + db_url_or_db, engine_args, base, query_class + ) self._model_class = model_class @contextmanager diff --git a/dbgpt/storage/metadata/tests/test_db_manager.py b/dbgpt/storage/metadata/tests/test_db_manager.py index 645f6271a..a6ad24caa 100644 --- a/dbgpt/storage/metadata/tests/test_db_manager.py +++ b/dbgpt/storage/metadata/tests/test_db_manager.py @@ -1,5 +1,6 @@ from __future__ import annotations import pytest +import tempfile from typing import Type from dbgpt.storage.metadata.db_manager import ( DatabaseManager, @@ -103,7 +104,6 @@ class User(Model): db.create_all() - # 添加数据 with db.session() as session: for i in range(30): user = User(name=f"User {i}") @@ -127,3 +127,29 @@ class User(Model): User.query.paginate_query(page=0, per_page=10) with pytest.raises(ValueError): User.query.paginate_query(page=1, per_page=-1) + + +def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]): + assert db.metadata.tables == {} + + class User(Model): + __tablename__ = "user" + id = Column(Integer, primary_key=True) + name = Column(String(50)) + + with tempfile.NamedTemporaryFile(delete=True) as db_file: + filename = db_file.name + new_db = DatabaseManager.build_from( + f"sqlite:///{filename}", base=Model, override_query_class=True + ) + Model.set_db_manager(new_db) + new_db.create_all() + db.create_all() + assert list(new_db.metadata.tables.keys())[0] == "user" + User.create(**{"name": "John Doe"}) + with new_db.session() as session: + assert session.query(User).filter_by(name="John Doe").first() is not None + with db.session() as session: + assert 
session.query(User).filter_by(name="John Doe").first() is None + assert len(User.query.all()) == 1 + assert User.query.filter(User.name == "John Doe").first().name == "John Doe" diff --git a/examples/awel/simple_chat_dag_example.py b/examples/awel/simple_chat_dag_example.py index cd6b8b12a..50a26fb0e 100644 --- a/examples/awel/simple_chat_dag_example.py +++ b/examples/awel/simple_chat_dag_example.py @@ -6,7 +6,8 @@ .. code-block:: shell - curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_chat \ + DBGPT_SERVER="http://127.0.0.1:5000" + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_chat \ -H "Content-Type: application/json" -d '{ "model": "proxyllm", "user_input": "hello" @@ -52,3 +53,14 @@ async def map(self, input_value: TriggerReqBody) -> Dict: # type(out) == ModelOutput model_parse_task = MapOperator(lambda out: out.to_dict()) trigger >> request_handle_task >> model_task >> model_parse_task + + +if __name__ == "__main__": + if dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. + from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([dag], port=5555) + else: + # Production mode, DB-GPT will automatically load and execute the current file after startup. + pass diff --git a/examples/awel/simple_chat_history_example.py b/examples/awel/simple_chat_history_example.py new file mode 100644 index 000000000..f0c1a6aa5 --- /dev/null +++ b/examples/awel/simple_chat_history_example.py @@ -0,0 +1,186 @@ +"""AWEL: Simple chat with history example + + DB-GPT will automatically load and execute the current file after startup. + + Examples: + + Call with non-streaming response. + .. 
code-block:: shell + + DBGPT_SERVER="http://127.0.0.1:5000" + MODEL="gpt-3.5-turbo" + # First round + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ + -H "Content-Type: application/json" -d '{ + "model": "gpt-3.5-turbo", + "context": { + "conv_uid": "uuid_conv_1234" + }, + "messages": "Who is elon musk?" + }' + + # Second round + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ + -H "Content-Type: application/json" -d '{ + "model": "gpt-3.5-turbo", + "context": { + "conv_uid": "uuid_conv_1234" + }, + "messages": "Is he rich?" + }' + + Call with streaming response. + .. code-block:: shell + + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ + -H "Content-Type: application/json" -d '{ + "model": "gpt-3.5-turbo", + "context": { + "conv_uid": "uuid_conv_stream_1234" + }, + "stream": true, + "messages": "Who is elon musk?" + }' + + # Second round + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \ + -H "Content-Type: application/json" -d '{ + "model": "gpt-3.5-turbo", + "context": { + "conv_uid": "uuid_conv_stream_1234" + }, + "stream": true, + "messages": "Is he rich?" 
+ }' + + +""" +from typing import Dict, Any, Optional, Union, List +import logging +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core.awel import ( + DAG, + HttpTrigger, + MapOperator, + JoinOperator, +) +from dbgpt.core import LLMClient, InMemoryStorage +from dbgpt.core.operator import ( + LLMBranchOperator, + LLMOperator, + StreamingLLMOperator, + RequestBuildOperator, + PreConversationOperator, + PostConversationOperator, + PostStreamingConversationOperator, + BufferedConversationMapperOperator, +) +from dbgpt.model import OpenAIStreamingOperator, MixinLLMOperator + +logger = logging.getLogger(__name__) + + +class ReqContext(BaseModel): + user_name: Optional[str] = Field( + None, description="The user name of the model request." + ) + + sys_code: Optional[str] = Field( + None, description="The system code of the model request." + ) + conv_uid: Optional[str] = Field( + None, description="The conversation uid of the model request." + ) + + +class TriggerReqBody(BaseModel): + messages: Union[str, List[Dict[str, str]]] = Field( + ..., description="User input messages" + ) + model: str = Field(..., description="Model name") + stream: Optional[bool] = Field(default=False, description="Whether return stream") + context: Optional[ReqContext] = Field( + default=None, description="The context of the model request." + ) + + +class MyLLMOperator(MixinLLMOperator, LLMOperator): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + LLMOperator.__init__(self, llm_client, **kwargs) + + +class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + StreamingLLMOperator.__init__(self, llm_client, **kwargs) + + +with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag: + # Receive http request and trigger dag to run. 
+ trigger = HttpTrigger( + "/examples/simple_history/multi_round/chat/completions", + methods="POST", + request_body=TriggerReqBody, + streaming_predict_func=lambda req: req.stream, + ) + # Transform request body to model request. + request_handle_task = RequestBuildOperator() + # Pre-process conversation, use InMemoryStorage to store conversation. + pre_conversation_task = PreConversationOperator( + storage=InMemoryStorage(), message_storage=InMemoryStorage() + ) + # Keep last k round conversation. + history_conversation_task = BufferedConversationMapperOperator(last_k_round=5) + + # Save conversation to storage. + post_conversation_task = PostConversationOperator() + # Save streaming conversation to storage. + post_streaming_conversation_task = PostStreamingConversationOperator() + + # Use LLMOperator to generate response. + llm_task = MyLLMOperator(task_name="llm_task") + streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") + branch_task = LLMBranchOperator( + stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" + ) + model_parse_task = MapOperator(lambda out: out.to_dict()) + openai_format_stream_task = OpenAIStreamingOperator() + result_join_task = JoinOperator( + combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out + ) + + ( + trigger + >> request_handle_task + >> pre_conversation_task + >> history_conversation_task + >> branch_task + ) + + # The branch of no streaming response. + ( + branch_task + >> llm_task + >> post_conversation_task + >> model_parse_task + >> result_join_task + ) + # The branch of streaming response. + ( + branch_task + >> streaming_llm_task + >> post_streaming_conversation_task + >> openai_format_stream_task + >> result_join_task + ) + +if __name__ == "__main__": + if multi_round_dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. 
+ from dbgpt.core.awel import setup_dev_environment + + setup_dev_environment([multi_round_dag], port=5555) + else: + # Production mode, DB-GPT will automatically load and execute the current file after startup. + pass diff --git a/examples/awel/simple_llm_client_example.py b/examples/awel/simple_llm_client_example.py index 2d12d0eac..6394c41ce 100644 --- a/examples/awel/simple_llm_client_example.py +++ b/examples/awel/simple_llm_client_example.py @@ -2,45 +2,58 @@ DB-GPT will automatically load and execute the current file after startup. - Example: - - .. code-block:: shell - - DBGPT_SERVER="http://127.0.0.1:5000" - curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate \ - -H "Content-Type: application/json" -d '{ - "model": "proxyllm", - "messages": "hello" - }' - - curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/generate_stream \ - -H "Content-Type: application/json" -d '{ - "model": "proxyllm", - "messages": "hello", - "stream": true - }' - - curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \ - -H "Content-Type: application/json" -d '{ - "model": "proxyllm", - "messages": "hello" - }' + Examples: + + Call with non-streaming response. + .. code-block:: shell + + DBGPT_SERVER="http://127.0.0.1:5000" + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "messages": "hello" + }' + + Call with streaming response. + .. code-block:: shell + + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "messages": "hello", + "stream": true + }' + + Call model and count token. + .. 
code-block:: shell + + curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \ + -H "Content-Type: application/json" -d '{ + "model": "proxyllm", + "messages": "hello" + }' """ -from typing import Dict, Any, AsyncIterator, Optional, Union, List +from typing import Dict, Any, Optional, Union, List +import logging from dbgpt._private.pydantic import BaseModel, Field -from dbgpt.component import ComponentType -from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, TransformStreamAbsOperator -from dbgpt.core import ( - ModelMessage, - LLMClient, +from dbgpt.core.awel import ( + DAG, + HttpTrigger, + MapOperator, + JoinOperator, +) +from dbgpt.core import LLMClient + +from dbgpt.core.operator import ( + LLMBranchOperator, LLMOperator, StreamingLLMOperator, - ModelOutput, - ModelRequest, + RequestBuildOperator, ) -from dbgpt.model import DefaultLLMClient -from dbgpt.model.cluster import WorkerManagerFactory +from dbgpt.model import OpenAIStreamingOperator, MixinLLMOperator + +logger = logging.getLogger(__name__) class TriggerReqBody(BaseModel): @@ -51,58 +64,24 @@ class TriggerReqBody(BaseModel): stream: Optional[bool] = Field(default=False, description="Whether return stream") -class RequestHandleOperator(MapOperator[TriggerReqBody, ModelRequest]): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def map(self, input_value: TriggerReqBody) -> ModelRequest: - messages = [ModelMessage.build_human_message(input_value.messages)] - await self.current_dag_context.save_to_share_data( - "request_model_name", input_value.model - ) - return ModelRequest( - model=input_value.model, - messages=messages, - echo=False, - ) - - -class LLMMixin: - @property - def llm_client(self) -> LLMClient: - if not self._llm_client: - worker_manager = self.system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - self._llm_client = DefaultLLMClient(worker_manager) - return self._llm_client - - -class 
MyLLMOperator(LLMMixin, LLMOperator): - def __init__(self, llm_client: LLMClient = None, **kwargs): - super().__init__(llm_client, **kwargs) +class MyLLMOperator(MixinLLMOperator, LLMOperator): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + LLMOperator.__init__(self, llm_client, **kwargs) -class MyStreamingLLMOperator(LLMMixin, StreamingLLMOperator): - def __init__(self, llm_client: LLMClient = None, **kwargs): - super().__init__(llm_client, **kwargs) +class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + StreamingLLMOperator.__init__(self, llm_client, **kwargs) -class MyLLMStreamingOperator(TransformStreamAbsOperator[ModelOutput, str]): - async def transform_stream( - self, input_value: AsyncIterator[ModelOutput] - ) -> AsyncIterator[str]: - from dbgpt.model.utils.chatgpt_utils import _to_openai_stream - - model = await self.current_dag_context.get_share_data("request_model_name") - async for output in _to_openai_stream(model, input_value): - yield output - - -class MyModelToolOperator(LLMMixin, MapOperator[TriggerReqBody, Dict[str, Any]]): - def __init__(self, llm_client: LLMClient = None, **kwargs): - self._llm_client = llm_client - MapOperator.__init__(self, **kwargs) +class MyModelToolOperator( + MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]] +): + def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + super().__init__(llm_client) + MapOperator.__init__(self, llm_client, **kwargs) async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: prompt_tokens = await self.llm_client.count_token( @@ -118,25 +97,27 @@ async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag: # Receive http request and trigger dag to run. 
trigger = HttpTrigger( - "/examples/simple_client/generate", methods="POST", request_body=TriggerReqBody - ) - request_handle_task = RequestHandleOperator() - model_task = MyLLMOperator() - model_parse_task = MapOperator(lambda out: out.to_dict()) - trigger >> request_handle_task >> model_task >> model_parse_task - -with DAG("dbgpt_awel_simple_llm_client_generate_stream") as client_generate_stream_dag: - # Receive http request and trigger dag to run. - trigger = HttpTrigger( - "/examples/simple_client/generate_stream", + "/examples/simple_client/chat/completions", methods="POST", request_body=TriggerReqBody, - streaming_response=True, + streaming_predict_func=lambda req: req.stream, ) - request_handle_task = RequestHandleOperator() - model_task = MyStreamingLLMOperator() - openai_format_stream_task = MyLLMStreamingOperator() - trigger >> request_handle_task >> model_task >> openai_format_stream_task + request_handle_task = RequestBuildOperator() + llm_task = MyLLMOperator(task_name="llm_task") + streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task") + branch_task = LLMBranchOperator( + stream_task_name="streaming_llm_task", no_stream_task_name="llm_task" + ) + model_parse_task = MapOperator(lambda out: out.to_dict()) + openai_format_stream_task = OpenAIStreamingOperator() + result_join_task = JoinOperator( + combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out + ) + + trigger >> request_handle_task >> branch_task + branch_task >> llm_task >> model_parse_task >> result_join_task + branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task + with DAG("dbgpt_awel_simple_llm_client_count_token") as client_count_token_dag: # Receive http request and trigger dag to run. 
@@ -147,3 +128,15 @@ async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]: ) model_task = MyModelToolOperator() trigger >> model_task + + +if __name__ == "__main__": + if client_generate_dag.leaf_nodes[0].dev_mode: + # Development mode, you can run the dag locally for debugging. + from dbgpt.core.awel import setup_dev_environment + + dags = [client_generate_dag, client_count_token_dag] + setup_dev_environment(dags, port=5555) + else: + # Production mode, DB-GPT will automatically load and execute the current file after startup. + pass diff --git a/examples/sdk/simple_sdk_llm_example.py b/examples/sdk/simple_sdk_llm_example.py index a33c1ceec..40f1321cf 100644 --- a/examples/sdk/simple_sdk_llm_example.py +++ b/examples/sdk/simple_sdk_llm_example.py @@ -2,9 +2,11 @@ from dbgpt.core.awel import DAG from dbgpt.core import ( BaseOutputParser, - RequestBuildOperator, PromptTemplate, +) +from dbgpt.core.operator import ( LLMOperator, + RequestBuildOperator, ) from dbgpt.model import OpenAILLMClient diff --git a/examples/sdk/simple_sdk_llm_sql_example.py b/examples/sdk/simple_sdk_llm_sql_example.py index 249ae57a6..fcf0f9206 100644 --- a/examples/sdk/simple_sdk_llm_sql_example.py +++ b/examples/sdk/simple_sdk_llm_sql_example.py @@ -10,9 +10,11 @@ ) from dbgpt.core import ( SQLOutputParser, + PromptTemplate, +) +from dbgpt.core.operator import ( LLMOperator, RequestBuildOperator, - PromptTemplate, ) from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect from dbgpt.datasource.operator.datasource_operator import DatasourceOperator