diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py
index f48feceb0..10f738d3b 100644
--- a/dbgpt/app/scene/base_chat.py
+++ b/dbgpt/app/scene/base_chat.py
@@ -171,13 +171,13 @@ def llm_client(self) -> LLMClient:

     async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
         llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
-        return await llm_task.call(call_data={"data": request})
+        return await llm_task.call(call_data=request)

     async def call_streaming_operator(
         self, request: ModelRequest
     ) -> AsyncIterator[ModelOutput]:
         llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
-        async for out in await llm_task.call_stream(call_data={"data": request}):
+        async for out in await llm_task.call_stream(call_data=request):
             yield out

     def do_action(self, prompt_response):
@@ -251,11 +251,9 @@ async def _build_model_request(self) -> ModelRequest:
             str_history=self.prompt_template.str_history,
             request_context=req_ctx,
         )
-        node_input = {
-            "data": ChatComposerInput(
-                messages=self.history_messages, prompt_dict=input_values
-            )
-        }
+        node_input = ChatComposerInput(
+            messages=self.history_messages, prompt_dict=input_values
+        )
         # llm_messages = self.generate_llm_messages()
         model_request: ModelRequest = await node.call(call_data=node_input)
         model_request.context.cache_enable = self.model_cache_enable
diff --git a/dbgpt/app/scene/operator/app_operator.py b/dbgpt/app/scene/operator/app_operator.py
index c9e7f03b6..91f65e357 100644
--- a/dbgpt/app/scene/operator/app_operator.py
+++ b/dbgpt/app/scene/operator/app_operator.py
@@ -87,7 +87,7 @@ async def map(self, input_value: ChatComposerInput) -> ModelRequest:
         end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
         # Sub dag, use the same dag context in the parent dag
         messages = await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
         )
         span_id = self._request_context.span_id
         model_request = ModelRequest.build_request(
diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py
index 84e2fc775..c97ef1186 100644
--- a/dbgpt/core/awel/__init__.py
+++ b/dbgpt/core/awel/__init__.py
@@ -22,6 +22,7 @@
     JoinOperator,
     MapOperator,
     ReduceStreamOperator,
+    TriggerOperator,
 )
 from .operator.stream_operator import (
     StreamifyAbsOperator,
@@ -50,6 +51,7 @@
     "BaseOperator",
     "JoinOperator",
     "ReduceStreamOperator",
+    "TriggerOperator",
     "MapOperator",
     "BranchOperator",
     "InputOperator",
@@ -150,4 +152,6 @@ def setup_dev_environment(
     for trigger in dag.trigger_nodes:
         trigger_manager.register_trigger(trigger)
     trigger_manager.after_register()
-    uvicorn.run(app, host=host, port=port)
+    if trigger_manager.keep_running():
+        # Should keep running
+        uvicorn.run(app, host=host, port=port)
diff --git a/dbgpt/core/awel/operator/base.py b/dbgpt/core/awel/operator/base.py
index 0af890e7c..14aa1fb33 100644
--- a/dbgpt/core/awel/operator/base.py
+++ b/dbgpt/core/awel/operator/base.py
@@ -28,7 +28,7 @@

 F = TypeVar("F", bound=FunctionType)

-CALL_DATA = Union[Dict, Dict[str, Dict]]
+CALL_DATA = Union[Dict[str, Any], Any]


 class WorkflowRunner(ABC, Generic[T]):
@@ -197,6 +197,8 @@ async def call(
         Returns:
             OUT: The output of the node after execution.
         """
+        if call_data:
+            call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, exist_dag_ctx=dag_ctx
         )
@@ -242,6 +244,8 @@ async def call_stream(
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
+        if call_data:
+            call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
         )
diff --git a/dbgpt/core/awel/task/base.py b/dbgpt/core/awel/task/base.py
index f0c5712bc..fb816838e 100644
--- a/dbgpt/core/awel/task/base.py
+++ b/dbgpt/core/awel/task/base.py
@@ -28,6 +28,14 @@ def __bool__(self):
 SKIP_DATA = _EMPTY_DATA_TYPE()
 PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()

+
+def is_empty_data(data: Any):
+    """Check if the data is empty."""
+    if isinstance(data, _EMPTY_DATA_TYPE):
+        return data in (EMPTY_DATA, SKIP_DATA)
+    return False
+
+
 MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
 ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
 StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
diff --git a/dbgpt/core/awel/task/task_impl.py b/dbgpt/core/awel/task/task_impl.py
index 8877c5cfe..4b42443f5 100644
--- a/dbgpt/core/awel/task/task_impl.py
+++ b/dbgpt/core/awel/task/task_impl.py
@@ -24,7 +24,6 @@
     EMPTY_DATA,
     OUT,
     PLACEHOLDER_DATA,
-    SKIP_DATA,
     InputContext,
     InputSource,
     MapFunc,
@@ -37,6 +36,7 @@
     TaskState,
     TransformFunc,
     UnStreamFunc,
+    is_empty_data,
 )

 logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def new_output(self) -> TaskOutput[T]:
     @property
     def is_empty(self) -> bool:
         """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)

     @property
     def is_none(self) -> bool:
@@ -171,7 +171,7 @@ def is_stream(self) -> bool:
     @property
     def is_empty(self) -> bool:
         """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)

     @property
     def is_none(self) -> bool:
@@ -330,7 +330,7 @@ def _read_data(self, task_ctx: TaskContext) -> Any:
         """
         call_data = task_ctx.call_data
         data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
-        if data == EMPTY_DATA:
+        if is_empty_data(data):
             raise ValueError("No call data for current SimpleCallDataInputSource")
         return data

diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py
index 702334dc4..c77bd8b7e 100644
--- a/dbgpt/core/awel/trigger/http_trigger.py
+++ b/dbgpt/core/awel/trigger/http_trigger.py
@@ -152,7 +152,7 @@ async def _trigger_dag(
         raise ValueError("HttpTrigger just support one leaf node in dag")
     end_node = cast(BaseOperator, leaf_nodes[0])
     if not streaming_response:
-        return await end_node.call(call_data={"data": body})
+        return await end_node.call(call_data=body)
     else:
         headers = response_headers
         media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
                 "Connection": "keep-alive",
                 "Transfer-Encoding": "chunked",
             }
-        generator = await end_node.call_stream(call_data={"data": body})
+        generator = await end_node.call_stream(call_data=body)
         background_tasks = BackgroundTasks()
         background_tasks.add_task(dag._after_dag_end)
         return StreamingResponse(
diff --git a/dbgpt/core/awel/trigger/trigger_manager.py b/dbgpt/core/awel/trigger/trigger_manager.py
index c9baed58d..608e68fa9 100644
--- a/dbgpt/core/awel/trigger/trigger_manager.py
+++ b/dbgpt/core/awel/trigger/trigger_manager.py
@@ -24,6 +24,14 @@ class TriggerManager(ABC):
     def register_trigger(self, trigger: Any) -> None:
         """Register a trigger to current manager."""

+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return False
+

 class HttpTriggerManager(TriggerManager):
     """Http trigger manager.
@@ -64,6 +72,8 @@ def register_trigger(self, trigger: Any) -> None:
             self._trigger_map[trigger_id] = trigger

     def _init_app(self, system_app: SystemApp):
+        if not self.keep_running():
+            return
         logger.info(
             f"Include router {self._router} to prefix path {self._router_prefix}"
         )
@@ -72,6 +82,14 @@ def _init_app(self, system_app: SystemApp):
             raise RuntimeError("System app not initialized")
         app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])

+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return len(self._trigger_map) > 0
+

 class DefaultTriggerManager(TriggerManager, BaseComponent):
     """Default trigger manager for AWEL.
@@ -105,3 +123,11 @@ def after_register(self) -> None:
         """After register, init the trigger manager."""
         if self.system_app:
             self.http_trigger._init_app(self.system_app)
+
+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return self.http_trigger.keep_running()
diff --git a/dbgpt/core/interface/operator/composer_operator.py b/dbgpt/core/interface/operator/composer_operator.py
index a8896702a..bf7a41204 100644
--- a/dbgpt/core/interface/operator/composer_operator.py
+++ b/dbgpt/core/interface/operator/composer_operator.py
@@ -70,7 +70,7 @@ async def map(self, input_value: ChatComposerInput) -> ModelRequest:
         end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
         # Sub dag, use the same dag context in the parent dag
         return await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
         )

     def _build_composer_dag(self) -> DAG:
diff --git a/dbgpt/core/interface/operator/prompt_operator.py b/dbgpt/core/interface/operator/prompt_operator.py
index 5049c7026..f31172533 100644
--- a/dbgpt/core/interface/operator/prompt_operator.py
+++ b/dbgpt/core/interface/operator/prompt_operator.py
@@ -150,7 +150,7 @@ class PromptBuilderOperator(
                 )
             )

-        single_input = {"data": {"dialect": "mysql"}}
+        single_input = {"dialect": "mysql"}
         single_expected_messages = [
             ModelMessage(
                 content="Please write a mysql SQL count the length of a field",
diff --git a/dbgpt/serve/agent/team/layout/team_awel_layout.py b/dbgpt/serve/agent/team/layout/team_awel_layout.py
index c22578d30..2515bad2f 100644
--- a/dbgpt/serve/agent/team/layout/team_awel_layout.py
+++ b/dbgpt/serve/agent/team/layout/team_awel_layout.py
@@ -67,7 +67,7 @@ async def a_run_chat(
             message=start_message, sender=self, reviewer=reviewer
         )
         final_generate_context: AgentGenerateContext = await last_node.call(
-            call_data={"data": start_message_context}
+            call_data=start_message_context
         )
         last_message = final_generate_context.rely_messages[-1]

diff --git a/examples/sdk/simple_sdk_llm_example.py b/examples/sdk/simple_sdk_llm_example.py
index 4cf1254c2..f141c768b 100644
--- a/examples/sdk/simple_sdk_llm_example.py
+++ b/examples/sdk/simple_sdk_llm_example.py
@@ -20,8 +20,6 @@

 if __name__ == "__main__":
     output = asyncio.run(
-        out_parse_task.call(
-            call_data={"data": {"dialect": "mysql", "table_name": "user"}}
-        )
+        out_parse_task.call(call_data={"dialect": "mysql", "table_name": "user"})
     )
     print(f"output: \n\n{output}")
diff --git a/examples/sdk/simple_sdk_llm_sql_example.py b/examples/sdk/simple_sdk_llm_sql_example.py
index d2c344552..2eceefb8b 100644
--- a/examples/sdk/simple_sdk_llm_sql_example.py
+++ b/examples/sdk/simple_sdk_llm_sql_example.py
@@ -144,12 +144,10 @@ def _combine_result(self, sql_result_df, model_result: Dict) -> Dict:

 if __name__ == "__main__":
     input_data = {
-        "data": {
-            "db_name": "test_db",
-            "dialect": "sqlite",
-            "top_k": 5,
-            "user_input": "What is the name and age of the user with age less than 18",
-        }
+        "db_name": "test_db",
+        "dialect": "sqlite",
+        "top_k": 5,
+        "user_input": "What is the name and age of the user with age less than 18",
     }
     output = asyncio.run(sql_result_task.call(call_data=input_data))
     print(f"\nthoughts: {output.get('thoughts')}\n")
diff --git a/setup.py b/setup.py
index a2b03ca27..f2fa86f33 100644
--- a/setup.py
+++ b/setup.py
@@ -357,7 +357,7 @@ def llama_cpp_python_cuda_requires():

 def core_requires():
     """
-    pip install db-gpt or pip install "db-gpt[core]"
+    pip install dbgpt or pip install "dbgpt[core]"
     """
     setup_spec.extras["core"] = [
         "aiohttp==3.8.4",
@@ -433,7 +433,7 @@ def core_requires():

 def knowledge_requires():
     """
-    pip install "db-gpt[knowledge]"
+    pip install "dbgpt[knowledge]"
     """
     setup_spec.extras["knowledge"] = [
         "spacy==3.5.3",
@@ -450,7 +450,7 @@ def knowledge_requires():

 def llama_cpp_requires():
     """
-    pip install "db-gpt[llama_cpp]"
+    pip install "dbgpt[llama_cpp]"
     """
     setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
     llama_cpp_python_cuda_requires()
@@ -538,7 +538,7 @@ def quantization_requires():

 def all_vector_store_requires():
     """
-    pip install "db-gpt[vstore]"
+    pip install "dbgpt[vstore]"
     """
     setup_spec.extras["vstore"] = [
         "grpcio==1.47.5",  # maybe delete it
@@ -549,7 +549,7 @@ def all_vector_store_requires():

 def all_datasource_requires():
     """
-    pip install "db-gpt[datasource]"
+    pip install "dbgpt[datasource]"
     """

     setup_spec.extras["datasource"] = [
@@ -567,7 +567,7 @@ def all_datasource_requires():

 def openai_requires():
     """
-    pip install "db-gpt[openai]"
+    pip install "dbgpt[openai]"
     """
     setup_spec.extras["openai"] = ["tiktoken"]
     if BUILD_VERSION_OPENAI:
@@ -582,28 +582,28 @@ def openai_requires():

 def gpt4all_requires():
     """
-    pip install "db-gpt[gpt4all]"
+    pip install "dbgpt[gpt4all]"
     """
     setup_spec.extras["gpt4all"] = ["gpt4all"]


 def vllm_requires():
     """
-    pip install "db-gpt[vllm]"
+    pip install "dbgpt[vllm]"
     """
     setup_spec.extras["vllm"] = ["vllm"]


 def cache_requires():
     """
-    pip install "db-gpt[cache]"
+    pip install "dbgpt[cache]"
     """
     setup_spec.extras["cache"] = ["rocksdict"]


 def default_requires():
     """
-    pip install "db-gpt[default]"
+    pip install "dbgpt[default]"
     """
     setup_spec.extras["default"] = [
         # "tokenizers==0.13.3",
@@ -683,7 +683,7 @@ def init_install_requires():
 )

 setuptools.setup(
-    name="db-gpt",
+    name="dbgpt",
     packages=packages,
     version=DB_GPT_VERSION,
     author="csunny",
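All of the `call_data` edits above trace back to one change in `BaseOperator`: `call` and `call_stream` now wrap a non-empty `call_data` in `{"data": ...}` themselves, so callers pass the payload directly instead of building the wrapper dict by hand. A minimal sketch of the new calling convention, in the style of the updated SDK examples above (the `EchoOperator` below is hypothetical, invented only for illustration):

```python
import asyncio

from dbgpt.core.awel import DAG, MapOperator


class EchoOperator(MapOperator[dict, str]):
    """Hypothetical operator, used only to illustrate the call convention."""

    async def map(self, input_value: dict) -> str:
        return f"dialect={input_value['dialect']}"


with DAG("call_data_example_dag"):
    echo_task = EchoOperator()

# Before this change: echo_task.call(call_data={"data": {"dialect": "mysql"}})
# After this change the payload is passed directly; BaseOperator.call wraps it
# in {"data": ...} before SimpleCallDataInputSource reads it back out.
output = asyncio.run(echo_task.call(call_data={"dialect": "mysql"}))
print(output)  # dialect=mysql
```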