From 725e55756350aefe913001b409665ab8af6070d5 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Wed, 28 Aug 2024 12:01:29 +0800 Subject: [PATCH 01/60] feat: Support some higher-order operators --- dbgpt/app/component_configs.py | 12 + dbgpt/app/operators/__init__.py | 4 + dbgpt/app/operators/converter.py | 186 ++++++++ dbgpt/app/operators/datasource.py | 336 +++++++++++++ dbgpt/app/operators/llm.py | 443 ++++++++++++++++++ dbgpt/app/operators/rag.py | 191 ++++++++ dbgpt/core/awel/flow/__init__.py | 2 + dbgpt/core/awel/flow/base.py | 116 ++++- dbgpt/core/awel/flow/flow_factory.py | 37 +- dbgpt/core/awel/flow/ui.py | 39 +- dbgpt/core/awel/trigger/http_trigger.py | 3 + dbgpt/core/interface/llm.py | 3 + dbgpt/core/interface/message.py | 27 +- .../core/interface/operators/llm_operator.py | 33 +- .../interface/operators/prompt_operator.py | 31 +- dbgpt/core/interface/output_parser.py | 89 +++- dbgpt/core/interface/prompt.py | 12 + dbgpt/model/cluster/client.py | 4 +- dbgpt/model/operators/llm_operator.py | 27 +- dbgpt/model/utils/chatgpt_utils.py | 9 +- dbgpt/rag/summary/db_summary_client.py | 3 +- dbgpt/serve/agent/resource/datasource.py | 67 ++- dbgpt/serve/agent/resource/knowledge.py | 1 + dbgpt/serve/flow/api/endpoints.py | 19 +- dbgpt/serve/flow/service/service.py | 8 +- dbgpt/serve/rag/operators/knowledge_space.py | 2 +- 26 files changed, 1636 insertions(+), 68 deletions(-) create mode 100644 dbgpt/app/operators/__init__.py create mode 100644 dbgpt/app/operators/converter.py create mode 100644 dbgpt/app/operators/datasource.py create mode 100644 dbgpt/app/operators/llm.py create mode 100644 dbgpt/app/operators/rag.py diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index a8a0f24d1..418d9eae1 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -60,6 +60,7 @@ def initialize_components( _initialize_openapi(system_app) # Register serve apps register_serve_apps(system_app, CFG, param.port) + _initialize_operators() def _initialize_model_cache(system_app: SystemApp, port: int): @@ -128,3 +129,14 @@ def _initialize_openapi(system_app: SystemApp): from dbgpt.app.openapi.api_v1.editor.service import EditorService system_app.register(EditorService) + + +def _initialize_operators(): + from dbgpt.app.operators.converter import StringToInteger + from dbgpt.app.operators.datasource import ( + HODatasourceExecutorOperator, + HODatasourceRetrieverOperator, + ) + from dbgpt.app.operators.llm import HOLLMOperator, HOStreamingLLMOperator + from dbgpt.app.operators.rag import HOKnowledgeOperator + from dbgpt.serve.agent.resource.datasource import DatasourceResource diff --git a/dbgpt/app/operators/__init__.py b/dbgpt/app/operators/__init__.py new file mode 100644 index 000000000..353336a34 --- /dev/null +++ b/dbgpt/app/operators/__init__.py @@ -0,0 +1,4 @@ +"""Operators package. + +This package contains all higher-order operators that are used to build workflows. 
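+Each operator here carries full ViewMetadata (parameters, inputs and outputs) and is tagged with TAGS_ORDER_HIGH ("higher-order"), so the AWEL flow UI can expose it directly as a building block.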
+""" diff --git a/dbgpt/app/operators/converter.py b/dbgpt/app/operators/converter.py new file mode 100644 index 000000000..1115e0de4 --- /dev/null +++ b/dbgpt/app/operators/converter.py @@ -0,0 +1,186 @@ +"""Type Converter Operators.""" + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + Parameter, + ViewMetadata, +) +from dbgpt.util.i18n_utils import _ + +_INPUTS_STRING = IOField.build_from( + _("String"), + "string", + str, + description=_("The string to be converted to other types."), +) +_INPUTS_INTEGER = IOField.build_from( + _("Integer"), + "integer", + int, + description=_("The integer to be converted to other types."), +) +_INPUTS_FLOAT = IOField.build_from( + _("Float"), + "float", + float, + description=_("The float to be converted to other types."), +) +_INPUTS_BOOLEAN = IOField.build_from( + _("Boolean"), + "boolean", + bool, + description=_("The boolean to be converted to other types."), +) + +_OUTPUTS_STRING = IOField.build_from( + _("String"), + "string", + str, + description=_("The string converted from other types."), +) +_OUTPUTS_INTEGER = IOField.build_from( + _("Integer"), + "integer", + int, + description=_("The integer converted from other types."), +) +_OUTPUTS_FLOAT = IOField.build_from( + _("Float"), + "float", + float, + description=_("The float converted from other types."), +) +_OUTPUTS_BOOLEAN = IOField.build_from( + _("Boolean"), + "boolean", + bool, + description=_("The boolean converted from other types."), +) + + +class StringToInteger(MapOperator[str, int]): + """Converts a string to an integer.""" + + metadata = ViewMetadata( + label=_("String to Integer"), + name="default_converter_string_to_integer", + description=_("Converts a string to an integer."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_STRING], + outputs=[_OUTPUTS_INTEGER], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new StringToInteger operator.""" + super().__init__(map_function=lambda x: int(x), **kwargs) + + +class StringToFloat(MapOperator[str, float]): + """Converts a string to a float.""" + + metadata = ViewMetadata( + label=_("String to Float"), + name="default_converter_string_to_float", + description=_("Converts a string to a float."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_STRING], + outputs=[_OUTPUTS_FLOAT], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new StringToFloat operator.""" + super().__init__(map_function=lambda x: float(x), **kwargs) + + +class StringToBoolean(MapOperator[str, bool]): + """Converts a string to a boolean.""" + + metadata = ViewMetadata( + label=_("String to Boolean"), + name="default_converter_string_to_boolean", + description=_("Converts a string to a boolean, true: 'true', '1', 'y'"), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[ + Parameter.build_from( + _("True Values"), + "true_values", + str, + optional=True, + default="true,1,y", + description=_("Comma-separated values that should be treated as True."), + ) + ], + inputs=[_INPUTS_STRING], + outputs=[_OUTPUTS_BOOLEAN], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, true_values: str = "true,1,y", **kwargs): + """Create a new StringToBoolean operator.""" + true_values_list = true_values.split(",") + true_values_list = [x.strip().lower() for x in true_values_list] + super().__init__(map_function=lambda x: x.lower() in true_values_list, 
**kwargs) + + +class IntegerToString(MapOperator[int, str]): + """Converts an integer to a string.""" + + metadata = ViewMetadata( + label=_("Integer to String"), + name="default_converter_integer_to_string", + description=_("Converts an integer to a string."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_INTEGER], + outputs=[_OUTPUTS_STRING], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new IntegerToString operator.""" + super().__init__(map_function=lambda x: str(x), **kwargs) + + +class FloatToString(MapOperator[float, str]): + """Converts a float to a string.""" + + metadata = ViewMetadata( + label=_("Float to String"), + name="default_converter_float_to_string", + description=_("Converts a float to a string."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_FLOAT], + outputs=[_OUTPUTS_STRING], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new FloatToString operator.""" + super().__init__(map_function=lambda x: str(x), **kwargs) + + +class BooleanToString(MapOperator[bool, str]): + """Converts a boolean to a string.""" + + metadata = ViewMetadata( + label=_("Boolean to String"), + name="default_converter_boolean_to_string", + description=_("Converts a boolean to a string."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_BOOLEAN], + outputs=[_OUTPUTS_STRING], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new BooleanToString operator.""" + super().__init__(map_function=lambda x: str(x), **kwargs) diff --git a/dbgpt/app/operators/datasource.py b/dbgpt/app/operators/datasource.py new file mode 100644 index 000000000..7fe16feaa --- /dev/null +++ b/dbgpt/app/operators/datasource.py @@ -0,0 +1,336 @@ +import json +import logging +from typing import List, Optional + +from dbgpt._private.config import Config +from dbgpt.agent.resource.database import DBResource +from dbgpt.core.awel import DAGContext, MapOperator +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + Parameter, + ViewMetadata, + ui, +) +from dbgpt.core.operators import BaseLLM +from dbgpt.util.i18n_utils import _ +from dbgpt.vis.tags.vis_chart import default_chart_type_prompt + +from .llm import HOContextBody + +logger = logging.getLogger(__name__) + +CFG = Config() + +_DEFAULT_CHART_TYPE = default_chart_type_prompt() + +_DEFAULT_TEMPLATE_EN = """You are a database expert. +Please answer the user's question based on the database selected by the user and some \ +of the available table structure definitions of the database. +Database name: + {db_name} +Table structure definition: + {table_info} + +Constraint: + 1.Please understand the user's intention based on the user's question, and use the \ + given table structure definition to create a grammatically correct {dialect} sql. \ + If sql is not required, answer the user's question directly.. + 2.Always limit the query to a maximum of {max_num_results} results unless the user \ + specifies in the question the specific number of rows of data he wishes to obtain. + 3.You can only use the tables provided in the table structure information to \ + generate sql. If you cannot generate sql based on the provided table structure, \ + please say: "The table structure information provided is not enough to generate \ + sql queries." It is prohibited to fabricate information at will. 
+ 4.Please be careful not to mistake the relationship between tables and columns \ + when generating SQL. + 5.Please check the correctness of the SQL and ensure that the query performance is \ + optimized under correct conditions. + 6.Please choose the best one from the display methods given below for data \ + rendering, and put the type name into the name parameter value of the \ + required return format. If you cannot find the most suitable one, use 'Table' as the \ + display method; the available data display methods are as follows: {display_type} + +User Question: + {user_input} +Please think step by step and respond according to the following JSON format: + {response} +Ensure the response is correct json and can be parsed by Python json.loads. +""" + +_DEFAULT_TEMPLATE_ZH = """你是一个数据库专家. +请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题. +数据库名: + {db_name} +表结构定义: + {table_info} + +约束: + 1. 请根据用户问题理解用户意图,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要 \ + sql,则直接回答用户问题。 + 2. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {max_num_results} \ + 个结果。 + 3. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:\ + “提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。 + 4. 请注意生成SQL时不要弄错表和列的关系 + 5. 请检查SQL的正确性,并保证正确的情况下优化查询性能 + 6. 请从如下给出的展示方式中选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值中\ + ,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {display_type} +用户问题: + {user_input} +请一步步思考并按照以下JSON格式回复: + {response} +确保返回正确的json并且可以被Python json.loads方法解析. +""" +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) + +_DEFAULT_RESPONSE = json.dumps( + { + "thoughts": "thoughts summary to say to user", + "sql": "SQL Query to run", + "display_type": "Data display method", + }, + ensure_ascii=False, + indent=4, +) + +_PARAMETER_DATASOURCE = Parameter.build_from( + _("Datasource"), + "datasource", + type=DBResource, + description=_("The datasource to retrieve the context"), +) +_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from( + _("Prompt Template"), + "prompt_template", + type=str, + optional=True, + default=_DEFAULT_TEMPLATE, + description=_("The prompt template to build a database prompt"), + ui=ui.DefaultUITextArea(), +) +_PARAMETER_DISPLAY_TYPE = Parameter.build_from( + _("Display Type"), + "display_type", + type=str, + optional=True, + default=_DEFAULT_CHART_TYPE, + description=_("The display type for the data"), + ui=ui.DefaultUITextArea(), +) +_PARAMETER_MAX_NUM_RESULTS = Parameter.build_from( + _("Max Number of Results"), + "max_num_results", + type=int, + optional=True, + default=50, + description=_("The maximum number of results to return"), +) +_PARAMETER_RESPONSE_FORMAT = Parameter.build_from( + _("Response Format"), + "response_format", + type=str, + optional=True, + default=_DEFAULT_RESPONSE, + description=_("The response format, default is a JSON format"), + ui=ui.DefaultUITextArea(), +) + +_PARAMETER_CONTEXT_KEY = Parameter.build_from( + _("Context Key"), + "context_key", + type=str, + optional=True, + default="context", + description=_("The key of the context, it will be used in building the prompt"), +) +_INPUTS_QUESTION = IOField.build_from( + _("User question"), + "query", + str, + description=_("The user question to retrieve table schemas from the datasource"), +) +_OUTPUTS_CONTEXT = IOField.build_from( + _("Retrieved context"), + "context", + HOContextBody, + description=_("The retrieved context from the datasource"), +) + +_INPUTS_SQL_DICT = IOField.build_from( + _("SQL dict"), + "sql_dict", + dict, + description=_("The SQL to be executed wrapped in a dictionary, generated by LLM"), +) +_OUTPUTS_SQL_RESULT =
IOField.build_from( + _("SQL result"), + "sql_result", + str, + description=_("The result of the SQL execution"), +) + +_INPUTS_SQL_DICT_LIST = IOField.build_from( + _("SQL dict list"), + "sql_dict_list", + dict, + description=_( + "The SQL list to be executed wrapped in a dictionary, generated by LLM" + ), + is_list=True, +) + + +class GPTVisMixin: + async def save_view_message(self, dag_ctx: DAGContext, view: str): + """Save the view message.""" + await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view) + + +class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]): + """Retrieve the table schemas from the datasource.""" + + metadata = ViewMetadata( + label=_("Datasource Retriever Operator"), + name="higher_order_datasource_retriever_operator", + description=_("Retrieve the table schemas from the datasource."), + category=OperatorCategory.DATABASE, + parameters=[ + _PARAMETER_DATASOURCE.new(), + _PARAMETER_PROMPT_TEMPLATE.new(), + _PARAMETER_DISPLAY_TYPE.new(), + _PARAMETER_MAX_NUM_RESULTS.new(), + _PARAMETER_RESPONSE_FORMAT.new(), + _PARAMETER_CONTEXT_KEY.new(), + ], + inputs=[_INPUTS_QUESTION.new()], + outputs=[_OUTPUTS_CONTEXT.new()], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__( + self, + datasource: DBResource, + prompt_template: str = _DEFAULT_TEMPLATE, + display_type: str = _DEFAULT_CHART_TYPE, + max_num_results: int = 50, + response_format: str = _DEFAULT_RESPONSE, + context_key: Optional[str] = "context", + **kwargs, + ): + """Initialize the operator.""" + super().__init__(**kwargs) + self._datasource = datasource + self._prompt_template = prompt_template + self._display_type = display_type + self._max_num_results = max_num_results + self._response_format = response_format + self._context_key = context_key + + async def map(self, question: str) -> HOContextBody: + """Retrieve the context from the datasource.""" + db_name = self._datasource._db_name + dialect = self._datasource.dialect + schema_info = await self.blocking_func_to_async( + self._datasource.get_schema_link, + db=db_name, + question=question, + ) + context = self._prompt_template.format( + db_name=db_name, + table_info=schema_info, + dialect=dialect, + max_num_results=self._max_num_results, + display_type=self._display_type, + user_input=question, + response=self._response_format, + ) + + return HOContextBody( + context_key=self._context_key, + context=context, + ) + + +class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]): + """Execute the context from the datasource.""" + + metadata = ViewMetadata( + label=_("Datasource Executor Operator"), + name="higher_order_datasource_executor_operator", + description=_("Execute the context from the datasource."), + category=OperatorCategory.DATABASE, + parameters=[_PARAMETER_DATASOURCE.new()], + inputs=[_INPUTS_SQL_DICT.new()], + outputs=[_OUTPUTS_SQL_RESULT.new()], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, datasource: DBResource, **kwargs): + """Initialize the operator.""" + MapOperator.__init__(self, **kwargs) + self._datasource = datasource + + async def map(self, sql_dict: dict) -> str: + """Execute the context from the datasource.""" + from dbgpt.vis.tags.vis_chart import VisChart + + if not isinstance(sql_dict, dict): + raise ValueError( + "The input value of datasource executor should be a dictionary." 
+ ) + vis = VisChart() + sql = sql_dict.get("sql") + if not sql: + return sql_dict.get("thoughts", "No SQL found in the input dictionary.") + data_df = await self._datasource.query_to_df(sql) + view = await vis.display(chart=sql_dict, data_df=data_df) + await self.save_view_message(self.current_dag_context, view) + return view + + +class HODatasourceDashboardOperator(GPTVisMixin, MapOperator[dict, str]): + """Execute the context from the datasource.""" + + metadata = ViewMetadata( + label=_("Datasource Dashboard Operator"), + name="higher_order_datasource_dashboard_operator", + description=_("Execute the context from the datasource."), + category=OperatorCategory.DATABASE, + parameters=[_PARAMETER_DATASOURCE.new()], + inputs=[_INPUTS_SQL_DICT_LIST.new()], + outputs=[_OUTPUTS_SQL_RESULT.new()], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, datasource: DBResource, **kwargs): + """Initialize the operator.""" + MapOperator.__init__(self, **kwargs) + self._datasource = datasource + + async def map(self, sql_dict_list: List[dict]) -> str: + """Execute the context from the datasource.""" + from dbgpt.vis.tags.vis_dashboard import VisDashboard + + if not isinstance(sql_dict_list, list): + raise ValueError( + "The input value of datasource executor should be a list of dictionaries." + ) + vis = VisDashboard() + chart_params = [] + for chart_item in sql_dict_list: + chart_dict = {k: v for k, v in chart_item.items()} + sql = chart_item.get("sql") + try: + data_df = await self._datasource.query_to_df(sql) + chart_dict["data"] = data_df + except Exception as e: + logger.warning(f"Sql execute failed!{str(e)}") + chart_dict["err_msg"] = str(e) + chart_params.append(chart_dict) + view = await vis.display(charts=chart_params) + await self.save_view_message(self.current_dag_context, view) + return view diff --git a/dbgpt/app/operators/llm.py b/dbgpt/app/operators/llm.py new file mode 100644 index 000000000..56b67a010 --- /dev/null +++ b/dbgpt/app/operators/llm.py @@ -0,0 +1,443 @@ +from typing import List, Literal, Optional, Tuple, Union + +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.core import ( + BaseMessage, + ChatPromptTemplate, + LLMClient, + ModelOutput, + ModelRequest, + StorageConversation, +) +from dbgpt.core.awel import ( + DAG, + BaseOperator, + CommonLLMHttpRequestBody, + DAGContext, + DefaultInputContext, + InputOperator, + JoinOperator, + MapOperator, + SimpleCallDataInputSource, + TaskOutput, +) +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + OptionValue, + Parameter, + ViewMetadata, + ui, +) +from dbgpt.core.interface.operators.message_operator import ( + BaseConversationOperator, + BufferedConversationMapperOperator, + TokenBufferedConversationMapperOperator, +) +from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator +from dbgpt.model.operators import LLMOperator, StreamingLLMOperator +from dbgpt.serve.conversation.serve import Serve as ConversationServe +from dbgpt.util.i18n_utils import _ +from dbgpt.util.tracer import root_tracer + + +class HOContextBody(BaseModel): + """Higher-order context body.""" + + context_key: str = Field( + "context", + description=_("The context key can be used as the key for formatting prompt."), + ) + context: Union[str, List[str]] = Field( + ..., + description=_("The context."), + ) + + +class BaseHOLLMOperator( + BaseConversationOperator, + JoinOperator[ModelRequest], + LLMOperator, + StreamingLLMOperator, +): + """Higher-order model request builder 
operator.""" + + def __init__( + self, + prompt_template: ChatPromptTemplate, + model: str = None, + llm_client: Optional[LLMClient] = None, + history_merge_mode: Literal["none", "window", "token"] = "window", + user_message_key: str = "user_input", + history_key: Optional[str] = None, + keep_start_rounds: Optional[int] = None, + keep_end_rounds: Optional[int] = None, + max_token_limit: int = 2048, + **kwargs, + ): + JoinOperator.__init__(self, combine_function=self._join_func, **kwargs) + LLMOperator.__init__(self, llm_client=llm_client, **kwargs) + StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs) + + # User must select a history merge mode + self._history_merge_mode = history_merge_mode + self._user_message_key = user_message_key + self._has_history = history_merge_mode != "none" + self._prompt_template = prompt_template + self._model = model + self._history_key = history_key + self._str_history = False + self._keep_start_rounds = keep_start_rounds if self._has_history else 0 + self._keep_end_rounds = keep_end_rounds if self._has_history else 0 + self._max_token_limit = max_token_limit + self._sub_compose_dag = self._build_conversation_composer_dag() + + async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]: + conv_serve = ConversationServe.get_instance(self.system_app) + self._storage = conv_serve.conv_storage + self._message_storage = conv_serve.message_storage + + _: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx) + dag_ctx.current_task_context.set_task_input( + DefaultInputContext([dag_ctx.current_task_context]) + ) + if dag_ctx.streaming_call: + task_output = await StreamingLLMOperator._do_run(self, dag_ctx) + else: + task_output = await LLMOperator._do_run(self, dag_ctx) + + return task_output + + async def after_dag_end(self, event_loop_task_id: int): + model_output: Optional[ + ModelOutput + ] = await self.current_dag_context.get_from_share_data( + LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT + ) + model_output_view: Optional[ + str + ] = await self.current_dag_context.get_from_share_data( + LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW + ) + storage_conv = await self.get_storage_conversation() + end_current_round: bool = False + if model_output and storage_conv: + # Save model output message to storage + storage_conv.add_ai_message(model_output.text) + end_current_round = True + if model_output_view and storage_conv: + # Save model output view to storage + storage_conv.add_view_message(model_output_view) + end_current_round = True + if end_current_round: + # End current conversation round and flush to storage + storage_conv.end_current_round() + + async def _join_func(self, req: CommonLLMHttpRequestBody, *args): + dynamic_inputs = [] + for arg in args: + if isinstance(arg, HOContextBody): + dynamic_inputs.append(arg) + # Load and store chat history, default use InMemoryStorage. + storage_conv, history_messages = await self.blocking_func_to_async( + self._build_storage, req + ) + # Save the storage conversation to share data, for the child operators + await self.current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv + ) + + user_input = ( + req.messages[-1] if isinstance(req.messages, list) else req.messages + ) + prompt_dict = { + self._user_message_key: user_input, + } + for dynamic_input in dynamic_inputs: + if dynamic_input.context_key in prompt_dict: + raise ValueError( + f"Duplicate context key '{dynamic_input.context_key}' in upstream " + f"operators." 
+ ) + prompt_dict[dynamic_input.context_key] = dynamic_input.context + + call_data = { + "messages": history_messages, + "prompt_dict": prompt_dict, + } + end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] + # Sub dag, use the same dag context in the parent dag + messages = await end_node.call(call_data, dag_ctx=self.current_dag_context) + model_request = ModelRequest.build_request( + model=req.model, + messages=messages, + context=req.context, + temperature=req.temperature, + max_new_tokens=req.max_new_tokens, + span_id=root_tracer.get_current_span_id(), + echo=False, + ) + if storage_conv: + # Start new round + storage_conv.start_new_round() + storage_conv.add_user_message(user_input) + return model_request + + def _build_storage( + self, req: CommonLLMHttpRequestBody + ) -> Tuple[StorageConversation, List[BaseMessage]]: + # Create a new storage conversation, this will load the conversation from + # storage, so we must do this async + storage_conv: StorageConversation = StorageConversation( + conv_uid=req.conv_uid, + chat_mode=req.chat_mode, + user_name=req.user_name, + sys_code=req.sys_code, + conv_storage=self.storage, + message_storage=self.message_storage, + param_type="", + param_value=req.chat_param, + ) + # Get history messages from storage + history_messages: List[BaseMessage] = storage_conv.get_history_message( + include_system_message=False + ) + + return storage_conv, history_messages + + def _build_conversation_composer_dag(self) -> DAG: + with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag: + input_task = InputOperator(input_source=SimpleCallDataInputSource()) + # History transform task + if self._history_merge_mode == "token": + history_transform_task = TokenBufferedConversationMapperOperator( + model=self._model, + llm_client=self.llm_client, + max_token_limit=self._max_token_limit, + ) + else: + history_transform_task = BufferedConversationMapperOperator( + keep_start_rounds=self._keep_start_rounds, + keep_end_rounds=self._keep_end_rounds, + ) + if self._history_key: + history_key = self._history_key + else: + placeholders = self._prompt_template.get_placeholders() + if not placeholders or len(placeholders) != 1: + raise ValueError( + "The prompt template must have exactly one placeholder if " + "history_key is not provided." + ) + history_key = placeholders[0] + history_prompt_build_task = HistoryPromptBuilderOperator( + prompt=self._prompt_template, + history_key=history_key, + check_storage=False, + save_to_storage=False, + str_history=self._str_history, + ) + # Build composer dag + ( + input_task + >> MapOperator(lambda x: x["messages"]) + >> history_transform_task + >> history_prompt_build_task + ) + ( + input_task + >> MapOperator(lambda x: x["prompt_dict"]) + >> history_prompt_build_task + ) + + return composer_dag + + +_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from( + _("Prompt Template"), + "prompt_template", + ChatPromptTemplate, + description=_("The prompt template for the conversation."), +) +_PARAMETER_MODEL = Parameter.build_from( + _("Model Name"), + "model", + str, + optional=True, + default=None, + description=_("The model name."), +) + +_PARAMETER_LLM_CLIENT = Parameter.build_from( + _("LLM Client"), + "llm_client", + LLMClient, + optional=True, + default=None, + description=_( + "The LLM Client, how to connect to the LLM model, if not provided, it will use" + " the default client deployed by DB-GPT." 
+ ), +) +_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from( + _("History Message Merge Mode"), + "history_merge_mode", + str, + optional=True, + default="none", + options=[ + OptionValue(label="No History", name="none", value="none"), + OptionValue(label="Message Window", name="window", value="window"), + OptionValue(label="Token Length", name="token", value="token"), + ], + description=_( + "The history merge mode, supports 'none', 'window' and 'token'." + " 'none': no history merge, 'window': merge by conversation window, 'token': " + "merge by token length." + ), + ui=ui.UISelect(), +) +_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from( + _("User Message Key"), + "user_message_key", + str, + optional=True, + default="user_input", + description=_( + "The key of the user message in your prompt, default is 'user_input'." + ), +) +_PARAMETER_HISTORY_KEY = Parameter.build_from( + _("History Key"), + "history_key", + str, + optional=True, + default=None, + description=_( + "The chat history key, with chat history message pass to prompt template, " + "if not provided, it will parse the prompt template to get the key." + ), +) +_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from( + _("Keep Start Rounds"), + "keep_start_rounds", + int, + optional=True, + default=None, + description=_("The start rounds to keep in the chat history."), +) +_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from( + _("Keep End Rounds"), + "keep_end_rounds", + int, + optional=True, + default=None, + description=_("The end rounds to keep in the chat history."), +) +_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from( + _("Max Token Limit"), + "max_token_limit", + int, + optional=True, + default=2048, + description=_("The max token limit to keep in the chat history."), +) + +_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from( + _("Common LLM Request Body"), + "common_llm_request_body", + CommonLLMHttpRequestBody, + _("The common LLM request body."), +) +_INPUTS_EXTRA_CONTEXT = IOField.build_from( + _("Extra Context"), + "extra_context", + HOContextBody, + _( + "Extra context for building prompt(Knowledge context, database " + "schema, etc), you can add multiple context." + ), + dynamic=True, +) +_OUTPUTS_MODEL_OUTPUT = IOField.build_from( + _("Model Output"), + "model_output", + ModelOutput, + description=_("The model output."), +) +_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from( + _("Streaming Model Output"), + "streaming_model_output", + ModelOutput, + is_list=True, + description=_("The streaming model output."), +) + + +class HOLLMOperator(BaseHOLLMOperator): + metadata = ViewMetadata( + label=_("LLM Operator"), + name="higher_order_llm_operator", + category=OperatorCategory.LLM, + description=_( + "High-level LLM operator, supports multi-round conversation " + "(conversation window, token length and no multi-round)." 
+ ), + parameters=[ + _PARAMETER_PROMPT_TEMPLATE.new(), + _PARAMETER_MODEL.new(), + _PARAMETER_LLM_CLIENT.new(), + _PARAMETER_HISTORY_MERGE_MODE.new(), + _PARAMETER_USER_MESSAGE_KEY.new(), + _PARAMETER_HISTORY_KEY.new(), + _PARAMETER_KEEP_START_ROUNDS.new(), + _PARAMETER_KEEP_END_ROUNDS.new(), + _PARAMETER_MAX_TOKEN_LIMIT.new(), + ], + inputs=[ + _INPUTS_COMMON_LLM_REQUEST_BODY.new(), + _INPUTS_EXTRA_CONTEXT.new(), + ], + outputs=[ + _OUTPUTS_MODEL_OUTPUT.new(), + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +class HOStreamingLLMOperator(BaseHOLLMOperator): + metadata = ViewMetadata( + label=_("Streaming LLM Operator"), + name="higher_order_streaming_llm_operator", + category=OperatorCategory.LLM, + description=_( + "High-level streaming LLM operator, supports multi-round conversation " + "(conversation window, token length and no multi-round)." + ), + parameters=[ + _PARAMETER_PROMPT_TEMPLATE.new(), + _PARAMETER_MODEL.new(), + _PARAMETER_LLM_CLIENT.new(), + _PARAMETER_HISTORY_MERGE_MODE.new(), + _PARAMETER_USER_MESSAGE_KEY.new(), + _PARAMETER_HISTORY_KEY.new(), + _PARAMETER_KEEP_START_ROUNDS.new(), + _PARAMETER_KEEP_END_ROUNDS.new(), + _PARAMETER_MAX_TOKEN_LIMIT.new(), + ], + inputs=[ + _INPUTS_COMMON_LLM_REQUEST_BODY.new(), + _INPUTS_EXTRA_CONTEXT.new(), + ], + outputs=[ + _OUTPUTS_STREAMING_MODEL_OUTPUT.new(), + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) diff --git a/dbgpt/app/operators/rag.py b/dbgpt/app/operators/rag.py new file mode 100644 index 000000000..79d166ac0 --- /dev/null +++ b/dbgpt/app/operators/rag.py @@ -0,0 +1,191 @@ +from typing import List, Optional + +from dbgpt._private.config import Config +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + FunctionDynamicOptions, + IOField, + OperatorCategory, + OptionValue, + Parameter, + ViewMetadata, + ui, +) +from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever +from dbgpt.util.i18n_utils import _ + +from .llm import HOContextBody + +CFG = Config() + + +def _load_space_name() -> List[OptionValue]: + from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity + + spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity()) + return [ + OptionValue(label=space.name, name=space.name, value=space.name) + for space in spaces + ] + + +_PARAMETER_CONTEXT_KEY = Parameter.build_from( + _("Context Key"), + "context", + type=str, + optional=True, + default="context", + description=_("The key of the context, it will be used in building the prompt"), +) +_PARAMETER_TOP_K = Parameter.build_from( + _("Top K"), + "top_k", + type=int, + optional=True, + default=5, + description=_("The number of chunks to retrieve"), +) +_PARAMETER_SCORE_THRESHOLD = Parameter.build_from( + _("Minimum Match Score"), + "score_threshold", + type=float, + optional=True, + default=0.3, + description=_( + _( + "The minimum match score for the retrieved chunks, it will be dropped if " + "the match score is less than the threshold" + ) + ), + ui=ui.UISlider(attr=ui.UISlider.UIAttribute(min=0.0, max=1.0, step=0.1)), +) + +_PARAMETER_RE_RANKER_ENABLED = Parameter.build_from( + _("Reranker Enabled"), + "reranker_enabled", + type=bool, + optional=True, + default=None, + description=_("Whether to enable the reranker"), +) +_PARAMETER_RE_RANKER_TOP_K = Parameter.build_from( + _("Reranker Top K"), + "reranker_top_k", + type=int, + optional=True, + 
default=3, + description=_("The top k for the reranker"), +) + +_INPUTS_QUESTION = IOField.build_from( + _("User question"), + "query", + str, + description=_("The user question to retrieve the knowledge"), +) +_OUTPUTS_CONTEXT = IOField.build_from( + _("Retrieved context"), + "context", + HOContextBody, + description=_("The retrieved context from the knowledge space"), +) + + +class HOKnowledgeOperator(MapOperator[str, HOContextBody]): + metadata = ViewMetadata( + label=_("Knowledge Operator"), + name="higher_order_knowledge_operator", + category=OperatorCategory.RAG, + description=_( + _( + "Knowledge Operator, retrieve your knowledge(documents) from knowledge" + " space" + ) + ), + parameters=[ + Parameter.build_from( + _("Knowledge Space Name"), + "knowledge_space", + type=str, + options=FunctionDynamicOptions(func=_load_space_name), + description=_("The name of the knowledge space"), + ), + _PARAMETER_CONTEXT_KEY.new(), + _PARAMETER_TOP_K.new(), + _PARAMETER_SCORE_THRESHOLD.new(), + _PARAMETER_RE_RANKER_ENABLED.new(), + _PARAMETER_RE_RANKER_TOP_K.new(), + ], + inputs=[ + _INPUTS_QUESTION.new(), + ], + outputs=[ + _OUTPUTS_CONTEXT.new(), + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__( + self, + knowledge_space: str, + context_key: Optional[str] = "context", + top_k: Optional[int] = None, + score_threshold: Optional[float] = None, + reranker_enabled: Optional[bool] = None, + reranker_top_k: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + self._knowledge_space = knowledge_space + self._context_key = context_key + self._top_k = top_k + self._score_threshold = score_threshold + self._reranker_enabled = reranker_enabled + self._reranker_top_k = reranker_top_k + + from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory + from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker + from dbgpt.serve.rag.models.models import ( + KnowledgeSpaceDao, + KnowledgeSpaceEntity, + ) + + spaces = KnowledgeSpaceDao().get_knowledge_space( + KnowledgeSpaceEntity(name=knowledge_space) + ) + if len(spaces) != 1: + raise Exception(f"invalid space name: {knowledge_space}") + space = spaces[0] + + reranker: Optional[RerankEmbeddingsRanker] = None + + if CFG.RERANK_MODEL and self._reranker_enabled: + reranker_top_k = ( + self._reranker_top_k + if self._reranker_top_k is not None + else CFG.RERANK_TOP_K + ) + rerank_embeddings = RerankEmbeddingFactory.get_instance( + CFG.SYSTEM_APP + ).create() + reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=reranker_top_k) + if self._top_k < reranker_top_k or self._top_k < 20: + # We use reranker, so if the top_k is less than 20, + # we need to set it to 20 + self._top_k = max(reranker_top_k, 20) + + self._space_retriever = KnowledgeSpaceRetriever( + space_id=space.id, + top_k=self._top_k, + rerank=reranker, + ) + + async def map(self, query: str) -> HOContextBody: + chunks = await self._space_retriever.aretrieve_with_scores( + query, self._score_threshold + ) + return HOContextBody( + context_key=self._context_key, + context=[chunk.content for chunk in chunks], + ) diff --git a/dbgpt/core/awel/flow/__init__.py b/dbgpt/core/awel/flow/__init__.py index 80db5b7e6..0d4e268c2 100644 --- a/dbgpt/core/awel/flow/__init__.py +++ b/dbgpt/core/awel/flow/__init__.py @@ -10,6 +10,7 @@ VariablesDynamicOptions, ) from .base import ( # noqa: F401 + TAGS_ORDER_HIGH, IOField, OperatorCategory, OperatorType, @@ -33,6 +34,7 @@ "ResourceCategory", "ResourceType", "OperatorType", + "TAGS_ORDER_HIGH", "IOField", "BaseDynamicOptions", 
"FunctionDynamicOptions", diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 314cb2171..db8bbcb84 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -40,6 +40,9 @@ T = TypeVar("T", bound="ViewMixin") TM = TypeVar("TM", bound="TypeMetadata") +TAGS_ORDER_HIGH = "higher-order" +TAGS_ORDER_FIRST = "first-order" + def _get_type_name(type_: Type[Any]) -> str: """Get the type name of the type. @@ -143,6 +146,8 @@ def __init__(self, label: str, description: str): "agent": _CategoryDetail("Agent", "The agent operator"), "rag": _CategoryDetail("RAG", "The RAG operator"), "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"), + "database": _CategoryDetail("Database", "Interact with the database"), + "type_converter": _CategoryDetail("Type Converter", "Convert the type"), "example": _CategoryDetail("Example", "Example operator"), } @@ -159,6 +164,8 @@ class OperatorCategory(str, Enum): AGENT = "agent" RAG = "rag" EXPERIMENTAL = "experimental" + DATABASE = "database" + TYPE_CONVERTER = "type_converter" EXAMPLE = "example" def label(self) -> str: @@ -202,6 +209,7 @@ class OperatorType(str, Enum): "embeddings": _CategoryDetail("Embeddings", "The embeddings resource"), "rag": _CategoryDetail("RAG", "The resource"), "vector_store": _CategoryDetail("Vector Store", "The vector store resource"), + "database": _CategoryDetail("Database", "Interact with the database"), "example": _CategoryDetail("Example", "The example resource"), } @@ -219,6 +227,7 @@ class ResourceCategory(str, Enum): EMBEDDINGS = "embeddings" RAG = "rag" VECTOR_STORE = "vector_store" + DATABASE = "database" EXAMPLE = "example" def label(self) -> str: @@ -372,32 +381,41 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: "value": values.get("value"), "default": values.get("default"), } + is_list = values.get("is_list") or False if type_cls: for k, v in to_handle_values.items(): if v: - handled_v = cls._covert_to_real_type(type_cls, v) + handled_v = cls._covert_to_real_type(type_cls, v, is_list) values[k] = handled_v return values @classmethod - def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any: - if type_cls and v is not None: - typed_value: Any = v + def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any: + def _parse_single_value(vv: Any) -> Any: + typed_value: Any = vv try: # Try to convert the value to the type. 
if type_cls == "builtins.str": - typed_value = str(v) + typed_value = str(vv) elif type_cls == "builtins.int": - typed_value = int(v) + typed_value = int(vv) elif type_cls == "builtins.float": - typed_value = float(v) + typed_value = float(vv) elif type_cls == "builtins.bool": - if str(v).lower() in ["false", "0", "", "no", "off"]: + if str(vv).lower() in ["false", "0", "", "no", "off"]: return False - typed_value = bool(v) + typed_value = bool(vv) return typed_value except ValueError: - raise ValidationError(f"Value '{v}' is not valid for type {type_cls}") + raise ValidationError(f"Value '{vv}' is not valid for type {type_cls}") + + if type_cls and v is not None: + if not is_list: + return _parse_single_value(v) + else: + if not isinstance(v, list): + raise ValidationError(f"Value '{v}' is not a list.") + return [_parse_single_value(vv) for vv in v] return v def get_typed_value(self) -> Any: @@ -413,11 +431,11 @@ def get_typed_value(self) -> Any: if is_variables and self.value is not None and isinstance(self.value, str): return VariablesPlaceHolder(self.name, self.value) else: - return self._covert_to_real_type(self.type_cls, self.value) + return self._covert_to_real_type(self.type_cls, self.value, self.is_list) def get_typed_default(self) -> Any: """Get the typed default.""" - return self._covert_to_real_type(self.type_cls, self.default) + return self._covert_to_real_type(self.type_cls, self.default, self.is_list) @classmethod def build_from( @@ -499,7 +517,10 @@ def to_dict(self) -> Dict: values = self.options.option_values() dict_value["options"] = [value.to_dict() for value in values] else: - dict_value["options"] = [value.to_dict() for value in self.options] + dict_value["options"] = [ + value.to_dict() if not isinstance(value, dict) else value + for value in self.options + ] if self.ui: dict_value["ui"] = self.ui.to_dict() @@ -594,6 +615,17 @@ def to_runnable_parameter( value = view_value return {self.name: value} + def new(self: TM) -> TM: + """Copy the metadata.""" + new_obj = self.__class__( + **self.model_dump(exclude_defaults=True, exclude={"ui", "options"}) + ) + if self.ui: + new_obj.ui = self.ui + if self.options: + new_obj.options = self.options + return new_obj + class BaseResource(Serializable, BaseModel): """The base resource.""" @@ -644,6 +676,17 @@ class IOField(Resource): description="Whether current field is list", examples=[True, False], ) + dynamic: bool = Field( + default=False, + description="Whether current field is dynamic", + examples=[True, False], + ) + dynamic_minimum: int = Field( + default=0, + description="The minimum count of the dynamic field, only valid when dynamic is" + " True", + examples=[0, 1, 2], + ) @classmethod def build_from( @@ -653,6 +696,8 @@ def build_from( type: Type, description: Optional[str] = None, is_list: bool = False, + dynamic: bool = False, + dynamic_minimum: int = 0, ): """Build the resource from the type.""" type_name = type.__qualname__ @@ -664,8 +709,22 @@ def build_from( type_cls=type_cls, is_list=is_list, description=description or label, + dynamic=dynamic, + dynamic_minimum=dynamic_minimum, ) + @model_validator(mode="before") + @classmethod + def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Pre fill the metadata.""" + if not isinstance(values, dict): + return values + if "dynamic" not in values: + values["dynamic"] = False + if "dynamic_minimum" not in values: + values["dynamic_minimum"] = 0 + return values + class BaseMetadata(BaseResource): """The base metadata.""" @@ -808,9 +867,40 @@ def
get_origin_id(self) -> str: split_ids = self.id.split("_") return "_".join(split_ids[:-1]) + def _parse_ui_size(self) -> Optional[str]: + """Parse the ui size.""" + if not self.parameters: + return None + parameters_size = set() + for parameter in self.parameters: + if parameter.ui and parameter.ui.size: + parameters_size.add(parameter.ui.size) + for size in ["large", "middle", "small"]: + if size in parameters_size: + return size + return None + def to_dict(self) -> Dict: """Convert current metadata to json dict.""" + from .ui import _size_to_order + dict_value = model_to_dict(self, exclude={"parameters"}) + tags = dict_value.get("tags") + if not tags: + tags = {"ui_version": "flow2.0"} + elif isinstance(tags, dict) and "ui_version" not in tags: + tags["ui_version"] = "flow2.0" + + parsed_ui_size = self._parse_ui_size() + if parsed_ui_size: + exist_size = tags.get("ui_size") + if not exist_size or _size_to_order(parsed_ui_size) > _size_to_order( + exist_size + ): + # Use the higher order size as current size. + tags["ui_size"] = parsed_ui_size + + dict_value["tags"] = tags dict_value["parameters"] = [ parameter.to_dict() for parameter in self.parameters ] diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 87b828971..fe7a83f50 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -97,6 +97,12 @@ def parse_data(cls, value: Any): return ResourceMetadata(**value) raise ValueError("Unable to infer the type for `data`") + def to_dict(self) -> Dict[str, Any]: + """Convert to dict.""" + dict_value = model_to_dict(self, exclude={"data"}) + dict_value["data"] = self.data.to_dict() + return dict_value + class FlowEdgeData(BaseModel): """Edge data in a flow.""" @@ -166,6 +172,12 @@ class FlowData(BaseModel): edges: List[FlowEdgeData] = Field(..., description="Edges in the flow") viewport: FlowPositionData = Field(..., description="Viewport of the flow") + def to_dict(self) -> Dict[str, Any]: + """Convert to dict.""" + dict_value = model_to_dict(self, exclude={"nodes"}) + dict_value["nodes"] = [n.to_dict() for n in self.nodes] + return dict_value + class _VariablesRequestBase(BaseModel): key: str = Field( @@ -518,9 +530,24 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["name"] = name return values + def model_dump(self, **kwargs): + """Override the model dump method.""" + exclude = kwargs.get("exclude", set()) + if "flow_dag" not in exclude: + exclude.add("flow_dag") + if "flow_data" not in exclude: + exclude.add("flow_data") + kwargs["exclude"] = exclude + common_dict = super().model_dump(**kwargs) + if self.flow_dag: + common_dict["flow_dag"] = None + if self.flow_data: + common_dict["flow_data"] = self.flow_data.to_dict() + return common_dict + def to_dict(self) -> Dict[str, Any]: """Convert to dict.""" - return model_to_dict(self, exclude={"flow_dag"}) + return model_to_dict(self, exclude={"flow_dag", "flow_data"}) def get_variables_dict(self) -> List[Dict[str, Any]]: """Get the variables dict.""" @@ -943,11 +970,17 @@ def fill_flow_panel(flow_panel: FlowPanel): new_param = input_parameters[i.name] i.label = new_param.label i.description = new_param.description + i.dynamic = new_param.dynamic + i.is_list = new_param.is_list + i.dynamic_minimum = new_param.dynamic_minimum for i in node.data.outputs: if i.name in output_parameters: new_param = output_parameters[i.name] i.label = new_param.label i.description = new_param.description + i.dynamic = new_param.dynamic + i.is_list = 
new_param.is_list + i.dynamic_minimum = new_param.dynamic_minimum else: data = cast(ResourceMetadata, node.data) key = data.get_origin_id() @@ -972,6 +1005,8 @@ def fill_flow_panel(flow_panel: FlowPanel): param.options = new_param.get_dict_options() # type: ignore param.default = new_param.default param.placeholder = new_param.placeholder + param.alias = new_param.alias + param.ui = new_param.ui except (FlowException, ValueError) as e: logger.warning(f"Unable to fill the flow panel: {e}") diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 928755a20..efe3d05e0 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union -from dbgpt._private.pydantic import BaseModel, Field, model_to_dict +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict, model_validator from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowUIComponentException @@ -25,6 +25,16 @@ "code_editor", ] +_UI_SIZE_TYPE = Literal["large", "middle", "small"] +_SIZE_ORDER = {"large": 6, "middle": 4, "small": 2} + + +def _size_to_order(size: str) -> int: + """Convert size to order.""" + if size not in _SIZE_ORDER: + return -1 + return _SIZE_ORDER[size] + class RefreshableMixin(BaseModel): """Refreshable mixin.""" @@ -81,6 +91,10 @@ class UIAttribute(BaseModel): ) ui_type: _UI_TYPE = Field(..., description="UI component type") + size: Optional[_UI_SIZE_TYPE] = Field( + None, + description="The size of the component(small, middle, large)", + ) attr: Optional[UIAttribute] = Field( None, @@ -266,6 +280,27 @@ class AutoSize(BaseModel): description="The attributes of the component", ) + @model_validator(mode="after") + def check_size(self) -> "UITextArea": + """Check the size. + + Automatically set the size to large if the max_rows is greater than 10. 
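+        An explicitly configured size is never overridden; only an unset size is upgraded.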
+ """ + attr = self.attr + auto_size = attr.auto_size if attr else None + if not attr or not auto_size or isinstance(auto_size, bool): + return self + max_rows = ( + auto_size.max_rows + if isinstance(auto_size, self.UIAttribute.AutoSize) + else None + ) + size = self.size + if not size and max_rows and max_rows > 10: + # Automatically set the size to large if the max_rows is greater than 10 + self.size = "large" + return self + class UIAutoComplete(UIInput): """Auto complete component.""" @@ -450,7 +485,7 @@ class DefaultUITextArea(UITextArea): attr: Optional[UITextArea.UIAttribute] = Field( default_factory=lambda: UITextArea.UIAttribute( - auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40) + auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20) ), description="The attributes of the component", ) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 8f0298297..33692a423 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -29,6 +29,7 @@ from ..dag.base import DAG from ..flow import ( + TAGS_ORDER_HIGH, IOField, OperatorCategory, OperatorType, @@ -965,6 +966,7 @@ class CommonLLMHttpTrigger(HttpTrigger): _PARAMETER_MEDIA_TYPE.new(), _PARAMETER_STATUS_CODE.new(), ], + tags={"order": TAGS_ORDER_HIGH}, ) def __init__( @@ -1203,6 +1205,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]): "User input parsed operator, parse the user input from request body and " "return as a string" ), + tags={"order": TAGS_ORDER_HIGH}, ) def __init__(self, key: str = "user_input", **kwargs): diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index e6a5d24d4..94de92a03 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -195,6 +195,9 @@ class ModelRequest: temperature: Optional[float] = None """The temperature of the model inference.""" + top_p: Optional[float] = None + """The top p of the model inference.""" + max_new_tokens: Optional[int] = None """The maximum number of tokens to generate.""" diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 50a7b39e5..f67b83cb8 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -317,6 +317,25 @@ def messages_to_string( """ return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix) + @staticmethod + def parse_user_message(messages: List[ModelMessage]) -> str: + """Parse user message from messages. + + Args: + messages (List[ModelMessage]): The all messages in the conversation. 
+ + Returns: + str: The user message + """ + last_user_message = None + for message in messages[::-1]: + if message.role == ModelMessageRoleType.HUMAN: + last_user_message = message.content + break + if not last_user_message: + raise ValueError("No user message") + return last_user_message + _SingleRoundMessage = List[BaseMessage] _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]] @@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]: content=ai_message.content, index=ai_message.index, round_index=ai_message.round_index, - additional_kwargs=ai_message.additional_kwargs.copy() - if ai_message.additional_kwargs - else {}, + additional_kwargs=( + ai_message.additional_kwargs.copy() + if ai_message.additional_kwargs + else {} + ), ) current_round.append(view_message) return sum(messages_by_round, []) diff --git a/dbgpt/core/interface/operators/llm_operator.py b/dbgpt/core/interface/operators/llm_operator.py index 45863d0a9..628c2f59f 100644 --- a/dbgpt/core/interface/operators/llm_operator.py +++ b/dbgpt/core/interface/operators/llm_operator.py @@ -246,10 +246,16 @@ class BaseLLM: SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name" SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output" + SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view" - def __init__(self, llm_client: Optional[LLMClient] = None): + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + ): """Create a new LLM operator.""" self._llm_client = llm_client + self._save_model_output = save_model_output @property def llm_client(self) -> LLMClient: @@ -262,9 +268,10 @@ async def save_model_output( self, current_dag_context: DAGContext, model_output: ModelOutput ) -> None: """Save the model output to the share data.""" - await current_dag_context.save_to_share_data( - self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output - ) + if self._save_model_output: + await current_dag_context.save_to_share_data( + self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output + ) class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): @@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC): This operator will generate a no streaming response. """ - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): """Create a new LLM operator.""" - super().__init__(llm_client=llm_client) + super().__init__(llm_client=llm_client, save_model_output=save_model_output) MapOperator.__init__(self, **kwargs) async def map(self, request: ModelRequest) -> ModelOutput: @@ -309,13 +321,18 @@ class BaseStreamingLLMOperator( This operator will generate streaming response. """ - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): """Create a streaming operator for a LLM. Args: llm_client (LLMClient, optional): The LLM client. Defaults to None.
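+            save_model_output (bool, optional): Whether to save the model output to the share data of the current DAG. Defaults to True.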
""" - super().__init__(llm_client=llm_client) + super().__init__(llm_client=llm_client, save_model_output=save_model_output) BaseOperator.__init__(self, **kwargs) async def streamify( # type: ignore diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py index 7d97230ac..241d8915f 100644 --- a/dbgpt/core/interface/operators/prompt_operator.py +++ b/dbgpt/core/interface/operators/prompt_operator.py @@ -4,14 +4,10 @@ from typing import Any, Dict, List, Optional, Union from dbgpt._private.pydantic import model_validator -from dbgpt.core import ( - ModelMessage, - ModelMessageRoleType, - ModelOutput, - StorageConversation, -) +from dbgpt.core import ModelMessage, ModelOutput, StorageConversation from dbgpt.core.awel import JoinOperator, MapOperator from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, IOField, OperatorCategory, OperatorType, @@ -42,6 +38,7 @@ name="common_chat_prompt_template", category=ResourceCategory.PROMPT, description=_("The operator to build the prompt with static prompt."), + tags={"order": TAGS_ORDER_HIGH}, parameters=[ Parameter.build_from( label=_("System Message"), @@ -101,9 +98,10 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: class BasePromptBuilderOperator(BaseConversationOperator, ABC): """The base prompt builder operator.""" - def __init__(self, check_storage: bool, **kwargs): + def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs): """Create a new prompt builder operator.""" super().__init__(check_storage=check_storage, **kwargs) + self._save_to_storage = save_to_storage async def format_prompt( self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any] @@ -122,8 +120,9 @@ async def format_prompt( pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables} messages = prompt.format_messages(**pass_kwargs) model_messages = ModelMessage.from_base_messages(messages) - # Start new round conversation, and save user message to storage - await self.start_new_round_conv(model_messages) + if self._save_to_storage: + # Start new round conversation, and save user message to storage + await self.start_new_round_conv(model_messages) return model_messages async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: @@ -132,13 +131,7 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: Args: messages (List[ModelMessage]): The messages. 
""" - lass_user_message = None - for message in messages[::-1]: - if message.role == ModelMessageRoleType.HUMAN: - lass_user_message = message.content - break - if not lass_user_message: - raise ValueError("No user message") + lass_user_message = ModelMessage.parse_user_message(messages) storage_conv: Optional[ StorageConversation ] = await self.get_storage_conversation() @@ -150,6 +143,8 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: async def after_dag_end(self, event_loop_task_id: int): """Execute after the DAG finished.""" + if not self._save_to_storage: + return # Save the storage conversation to storage after the whole DAG finished storage_conv: Optional[ StorageConversation @@ -422,7 +417,7 @@ def __init__( self._prompt = prompt self._history_key = history_key self._str_history = str_history - BasePromptBuilderOperator.__init__(self, check_storage=check_storage) + BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) @rearrange_args_by_type @@ -455,7 +450,7 @@ def __init__( """Create a new history dynamic prompt builder operator.""" self._history_key = history_key self._str_history = str_history - BasePromptBuilderOperator.__init__(self, check_storage=check_storage) + BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) @rearrange_args_by_type diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index faf29bfff..31e91b9f3 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -13,7 +13,13 @@ from dbgpt.core import ModelOutput from dbgpt.core.awel import MapOperator -from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + OperatorType, + ViewMetadata, +) from dbgpt.util.i18n_utils import _ T = TypeVar("T") @@ -271,7 +277,7 @@ async def map(self, input_value: ModelOutput) -> Any: if self.current_dag_context.streaming_call: return self.parse_model_stream_resp_ex(input_value, 0) else: - return self.parse_model_nostream_resp(input_value, "###") + return self.parse_model_nostream_resp(input_value, "#####################") def _parse_model_response(response: ResponseTye): @@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye): class SQLOutputParser(BaseOutputParser): """Parse the SQL output of an LLM call.""" + metadata = ViewMetadata( + label=_("SQL Output Parser"), + name="default_sql_output_parser", + category=OperatorCategory.OUTPUT_PARSER, + description=_("Parse the SQL output of an LLM call."), + parameters=[], + inputs=[ + IOField.build_from( + _("Model Output"), + "model_output", + ModelOutput, + description=_("The model output of upstream."), + ) + ], + outputs=[ + IOField.build_from( + _("Dict SQL Output"), + "dict", + dict, + description=_("The dict output after parsing."), + ) + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + def __init__(self, is_stream_out: bool = False, **kwargs): """Create a new SQL output parser.""" super().__init__(is_stream_out=is_stream_out, **kwargs) @@ -302,3 +333,57 @@ def parse_model_nostream_resp(self, response: ResponseTye, sep: str): model_out_text = super().parse_model_nostream_resp(response, sep) clean_str = super().parse_prompt_response(model_out_text) return json.loads(clean_str, strict=True) + + +class 
SQLListOutputParser(BaseOutputParser): + """Parse the SQL list output of an LLM call.""" + + metadata = ViewMetadata( + label=_("SQL List Output Parser"), + name="default_sql_list_output_parser", + category=OperatorCategory.OUTPUT_PARSER, + description=_( + "Parse the SQL list output of an LLM call, mostly used for dashboard." + ), + parameters=[], + inputs=[ + IOField.build_from( + _("Model Output"), + "model_output", + ModelOutput, + description=_("The model output of upstream."), + ) + ], + outputs=[ + IOField.build_from( + _("List SQL Output"), + "list", + dict, + is_list=True, + description=_("The list output after parsing."), + ) + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, is_stream_out: bool = False, **kwargs): + """Create a new SQL list output parser.""" + super().__init__(is_stream_out=is_stream_out, **kwargs) + + def parse_model_nostream_resp(self, response: ResponseTye, sep: str): + """Parse the output of an LLM call.""" + from dbgpt.util.json_utils import find_json_objects + + model_out_text = super().parse_model_nostream_resp(response, sep) + json_objects = find_json_objects(model_out_text) + json_count = len(json_objects) + if json_count < 1: + raise ValueError("Unable to obtain valid output.") + + parsed_json_list = json_objects[0] + if not isinstance(parsed_json_list, list): + if isinstance(parsed_json_list, dict): + return [parsed_json_list] + else: + raise ValueError("Invalid output format.") + return parsed_json_list diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index 99c4b9b10..d1d025d0a 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -254,6 +254,18 @@ def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["input_variables"] = sorted(input_variables) return values + def get_placeholders(self) -> List[str]: + """Get all placeholders in the prompt template. + + Returns: + List[str]: The placeholders. + """ + placeholders = set() + for message in self.messages: + if isinstance(message, MessagesPlaceholder): + placeholders.add(message.variable_name) + return sorted(placeholders) + @dataclasses.dataclass class PromptTemplateIdentifier(ResourceIdentifier): diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py index 7e0aa0214..d58645cb7 100644 --- a/dbgpt/model/cluster/client.py +++ b/dbgpt/model/cluster/client.py @@ -42,13 +42,13 @@ class DefaultLLMClient(LLMClient): Args: worker_manager (WorkerManager): worker manager instance. - auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False. + auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to True. """ def __init__( self, worker_manager: Optional[WorkerManager] = None, - auto_convert_message: bool = False, + auto_convert_message: bool = True, ): self._worker_manager = worker_manager self._auto_covert_message = auto_convert_message diff --git a/dbgpt/model/operators/llm_operator.py b/dbgpt/model/operators/llm_operator.py index 56eee1e3e..02f14fe73 100644 --- a/dbgpt/model/operators/llm_operator.py +++ b/dbgpt/model/operators/llm_operator.py @@ -24,8 +24,13 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC): This class extends BaseOperator by adding LLM capabilities. 
""" - def __init__(self, default_client: Optional[LLMClient] = None, **kwargs): - super().__init__(default_client) + def __init__( + self, + default_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): + super().__init__(default_client, save_model_output=save_model_output) @property def llm_client(self) -> LLMClient: @@ -95,8 +100,13 @@ class LLMOperator(MixinLLMOperator, BaseLLMOperator): ], ) - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): + super().__init__(llm_client, save_model_output=save_model_output) BaseLLMOperator.__init__(self, llm_client, **kwargs) @@ -144,6 +154,11 @@ class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator): ], ) - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): + super().__init__(llm_client, save_model_output=save_model_output) BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs) diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 057a04bf5..51c0fcae3 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -16,7 +16,13 @@ from dbgpt._private.pydantic import model_to_json from dbgpt.core.awel import TransformStreamAbsOperator -from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + OperatorType, + ViewMetadata, +) from dbgpt.core.interface.llm import ModelOutput from dbgpt.core.operators import BaseLLM from dbgpt.util.i18n_utils import _ @@ -184,6 +190,7 @@ class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str] ), ) ], + tags={"order": TAGS_ORDER_HIGH}, ) async def transform_stream(self, model_output: AsyncIterator[ModelOutput]): diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index de5ee83ff..8ce9a79e6 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -2,6 +2,7 @@ import logging import traceback +from typing import List from dbgpt._private.config import Config from dbgpt.component import SystemApp @@ -46,7 +47,7 @@ def db_summary_embedding(self, dbname, db_type): logger.info("db summary embedding success") - def get_db_summary(self, dbname, query, topk): + def get_db_summary(self, dbname, query, topk) -> List[str]: """Get user query related tables info.""" from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig diff --git a/dbgpt/serve/agent/resource/datasource.py b/dbgpt/serve/agent/resource/datasource.py index 5e37cdd0c..0be2127dd 100644 --- a/dbgpt/serve/agent/resource/datasource.py +++ b/dbgpt/serve/agent/resource/datasource.py @@ -3,14 +3,41 @@ from typing import Any, List, Optional, Type, Union, cast from dbgpt._private.config import Config -from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource +from dbgpt.agent.resource.database import ( + _DEFAULT_PROMPT_TEMPLATE, + _DEFAULT_PROMPT_TEMPLATE_ZH, + DBParameters, + RDBMSConnectorResource, +) +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + FunctionDynamicOptions, + OptionValue, + Parameter, + 
ResourceCategory,
+    register_resource,
+)
 from dbgpt.util import ParameterDescription
+from dbgpt.util.i18n_utils import _
 
 CFG = Config()
 
 logger = logging.getLogger(__name__)
 
 
+def _load_datasource() -> List[OptionValue]:
+    dbs = CFG.local_db_manager.get_db_list()
+    results = [
+        OptionValue(
+            label="[" + db["db_type"] + "]" + db["db_name"],
+            name=db["db_name"],
+            value=db["db_name"],
+        )
+        for db in dbs
+    ]
+    return results
+
+
 @dataclasses.dataclass
 class DatasourceDBParameters(DBParameters):
     """The DB parameters for the datasource."""
@@ -57,6 +84,44 @@ def from_dict(
         return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
 
 
+@register_resource(
+    _("Datasource Resource"),
+    "datasource",
+    category=ResourceCategory.DATABASE,
+    description=_(
+        "Connect to a datasource (retrieve table schemas and execute SQL to fetch data)."
+    ),
+    tags={"order": TAGS_ORDER_HIGH},
+    parameters=[
+        Parameter.build_from(
+            _("Datasource Name"),
+            "name",
+            str,
+            optional=True,
+            default="datasource",
+            description=_("The name of the datasource, default is 'datasource'."),
+        ),
+        Parameter.build_from(
+            _("DB Name"),
+            "db_name",
+            str,
+            description=_("The name of the database."),
+            options=FunctionDynamicOptions(func=_load_datasource),
+        ),
+        Parameter.build_from(
+            _("Prompt Template"),
+            "prompt_template",
+            str,
+            optional=True,
+            default=(
+                _DEFAULT_PROMPT_TEMPLATE_ZH
+                if CFG.LANGUAGE == "zh"
+                else _DEFAULT_PROMPT_TEMPLATE
+            ),
+            description=_("The prompt template to build a database prompt."),
+        ),
+    ],
+)
 class DatasourceResource(RDBMSConnectorResource):
     def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
         conn = CFG.local_db_manager.get_connector(db_name)
diff --git a/dbgpt/serve/agent/resource/knowledge.py b/dbgpt/serve/agent/resource/knowledge.py
index 90359be4c..65c062415 100644
--- a/dbgpt/serve/agent/resource/knowledge.py
+++ b/dbgpt/serve/agent/resource/knowledge.py
@@ -64,6 +64,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource):
     """Knowledge Space retriever resource."""
 
     def __init__(self, name: str, space_name: str, context: Optional[dict] = None):
+        # TODO: Build the retriever in a thread pool, as it blocks the event loop
         retriever = KnowledgeSpaceRetriever(
             space_id=space_name,
             top_k=context.get("top_k", None) if context else 4,
diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py
index 28f05d532..c6148f994 100644
--- a/dbgpt/serve/flow/api/endpoints.py
+++ b/dbgpt/serve/flow/api/endpoints.py
@@ -133,7 +133,10 @@ async def create(
     Returns:
         ServerResponse: The response
     """
-    return Result.succ(service.create_and_save_dag(request))
+    res = await blocking_func_to_async(
+        global_system_app, service.create_and_save_dag, request
+    )
+    return Result.succ(res)
 
 
 @router.put(
@@ -153,7 +156,8 @@ async def update(
     Returns:
         ServerResponse: The response
     """
-    return Result.succ(service.update_flow(request))
+    res = await blocking_func_to_async(global_system_app, service.update_flow, request)
+    return Result.succ(res)
 
 
 @router.delete("/flows/{uid}")
@@ -173,9 +177,7 @@ async def delete(
 
 
 @router.get("/flows/{uid}")
-async def get_flows(
-    uid: str, service: Service = Depends(get_service)
-) -> Result[ServerResponse]:
+async def get_flows(uid: str, service: Service = Depends(get_service)):
     """Get a Flow entity by uid
 
     Args:
@@ -188,7 +190,7 @@ async def get_flows(
     flow = service.get({"uid": uid})
    if not flow:
         raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
-    return Result.succ(flow)
+    return 
Result.succ(flow.model_dump()) @router.get( @@ -464,7 +466,10 @@ async def import_flow( status_code=400, detail=f"invalid file extension {file_extension}" ) if save_flow: - return Result.succ(service.create_and_save_dag(flow)) + res = await blocking_func_to_async( + global_system_app, service.create_and_save_dag, flow + ) + return Result.succ(res) else: return Result.succ(flow) diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 15c9d5ceb..3aac0b24c 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -27,7 +27,7 @@ ChatCompletionStreamResponse, DeltaMessage, ) -from dbgpt.serve.core import BaseService +from dbgpt.serve.core import BaseService, blocking_func_to_async from dbgpt.storage.metadata import BaseDao from dbgpt.storage.metadata._base_dao import QUERY_SPEC from dbgpt.util.dbgpts.loader import DBGPTsLoader @@ -590,7 +590,11 @@ async def debug_flow( """ from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata - dag = self._flow_factory.build(request.flow) + dag = await blocking_func_to_async( + self._system_app, + self._flow_factory.build, + request.flow, + ) leaf_nodes = dag.leaf_nodes if len(leaf_nodes) != 1: raise ValueError("Chat Flow just support one leaf node in dag") diff --git a/dbgpt/serve/rag/operators/knowledge_space.py b/dbgpt/serve/rag/operators/knowledge_space.py index c37495ed5..3d2e1d846 100644 --- a/dbgpt/serve/rag/operators/knowledge_space.py +++ b/dbgpt/serve/rag/operators/knowledge_space.py @@ -223,7 +223,7 @@ def __init__( self._prompt = prompt self._history_key = history_key self._str_history = str_history - BasePromptBuilderOperator.__init__(self, check_storage=check_storage) + BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_context, **kwargs) @rearrange_args_by_type From 1f676b9ebf077ca0355fe7ffc815fde7288650a9 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Wed, 28 Aug 2024 16:48:50 +0800 Subject: [PATCH 02/60] feat: Support endpoint placeholder --- dbgpt/core/awel/trigger/http_trigger.py | 51 ++++++++++++++-------- dbgpt/core/awel/trigger/trigger_manager.py | 9 ++-- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 33692a423..6e17be15e 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -58,6 +58,8 @@ logger = logging.getLogger(__name__) +ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}" + class AWELHttpError(RuntimeError): """AWEL Http Error.""" @@ -465,14 +467,11 @@ def mount_to_router( router (APIRouter): The router to mount the trigger. global_prefix (Optional[str], optional): The global prefix of the router. 
""" - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint dynamic_route_function = self._create_route_func() router.api_route( - self._endpoint, + endpoint, methods=self._methods, response_model=self._response_model, status_code=self._status_code, @@ -498,11 +497,9 @@ def mount_to_app( """ from dbgpt.util.fastapi import PriorityAPIRouter - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint dynamic_route_function = self._create_route_func() router = cast(PriorityAPIRouter, app.router) router.add_api_route( @@ -533,17 +530,28 @@ def remove_from_app( """ from fastapi import APIRouter - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint app_router = cast(APIRouter, app.router) for i, r in enumerate(app_router.routes): if r.path_format == path: # type: ignore # TODO, remove with path and methods del app_router.routes[i] + def _resolved_endpoint(self) -> str: + """Get the resolved endpoint. + + Replace the placeholder {dag_id} with the real dag_id. + """ + endpoint = self._endpoint + if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint: + return endpoint + if not self.dag: + raise AWELHttpError("DAG is not set") + dag_id = self.dag.dag_id + return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id) + def _trigger_mode(self) -> str: if ( self._req_body @@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger): ), ], parameters=[ - _PARAMETER_ENDPOINT.new(), + Parameter.build_from( + _("API Endpoint"), + "endpoint", + str, + optional=True, + default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID, + description=_("The API endpoint"), + ), _PARAMETER_METHODS_POST_PUT.new(), _PARAMETER_STREAMING_RESPONSE.new(), _PARAMETER_RESPONSE_BODY.new(), @@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger): def __init__( self, - endpoint: str, + endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID, methods: Optional[Union[str, List[str]]] = "POST", streaming_response: bool = False, http_response_body: Optional[Type[BaseHttpBody]] = None, diff --git a/dbgpt/core/awel/trigger/trigger_manager.py b/dbgpt/core/awel/trigger/trigger_manager.py index 45b040147..94563226e 100644 --- a/dbgpt/core/awel/trigger/trigger_manager.py +++ b/dbgpt/core/awel/trigger/trigger_manager.py @@ -81,7 +81,8 @@ def register_trigger( raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger") trigger_id = trigger.node_id if trigger_id not in self._trigger_map: - path = join_paths(self._router_prefix, trigger._endpoint) + real_endpoint = trigger._resolved_endpoint() + path = join_paths(self._router_prefix, real_endpoint) methods = trigger._methods # Check whether the route is already registered self._register_route_tables(path, methods) @@ -116,9 +117,9 @@ def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None: if not app: raise ValueError("System app not initialized") trigger.remove_from_app(app, self._router_prefix) - self._unregister_route_tables( - join_paths(self._router_prefix, trigger._endpoint), trigger._methods - ) + real_endpoint = trigger._resolved_endpoint() + path = join_paths(self._router_prefix, 
real_endpoint) + self._unregister_route_tables(path, trigger._methods) del self._trigger_map[trigger_id] def _init_app(self, system_app: SystemApp): From 439b5b32e2648f83436d072e2178c1855dc0c6c7 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 29 Aug 2024 12:03:14 +0800 Subject: [PATCH 03/60] feat: Support mappers in inputs and outputs --- dbgpt/app/operators/datasource.py | 29 +++++++- dbgpt/app/operators/rag.py | 19 +++++ dbgpt/core/awel/dag/base.py | 15 ++-- dbgpt/core/awel/flow/base.py | 11 +++ dbgpt/core/awel/flow/flow_factory.py | 89 ++++++++++++++++++++---- dbgpt/core/awel/trigger/http_trigger.py | 20 ++++++ dbgpt/serve/flow/service/service.py | 4 +- examples/awel/awel_flow_ui_components.py | 4 +- 8 files changed, 169 insertions(+), 22 deletions(-) diff --git a/dbgpt/app/operators/datasource.py b/dbgpt/app/operators/datasource.py index 7fe16feaa..320df8d22 100644 --- a/dbgpt/app/operators/datasource.py +++ b/dbgpt/app/operators/datasource.py @@ -4,6 +4,7 @@ from dbgpt._private.config import Config from dbgpt.agent.resource.database import DBResource +from dbgpt.core import Chunk from dbgpt.core.awel import DAGContext, MapOperator from dbgpt.core.awel.flow import ( TAGS_ORDER_HIGH, @@ -193,6 +194,19 @@ async def save_view_message(self, dag_ctx: DAGContext, view: str): class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]): """Retrieve the table schemas from the datasource.""" + _share_data_key = "__datasource_retriever_chunks__" + + class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]): + async def map(self, context: HOContextBody) -> List[Chunk]: + schema_info = await self.current_dag_context.get_from_share_data( + HODatasourceRetrieverOperator._share_data_key + ) + if isinstance(schema_info, list): + chunks = [Chunk(content=table_info) for table_info in schema_info] + else: + chunks = [Chunk(content=schema_info)] + return chunks + metadata = ViewMetadata( label=_("Datasource Retriever Operator"), name="higher_order_datasource_retriever_operator", @@ -207,7 +221,17 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]): _PARAMETER_CONTEXT_KEY.new(), ], inputs=[_INPUTS_QUESTION.new()], - outputs=[_OUTPUTS_CONTEXT.new()], + outputs=[ + _OUTPUTS_CONTEXT.new(), + IOField.build_from( + _("Retrieved schema chunks"), + "chunks", + Chunk, + is_list=True, + description=_("The retrieved schema chunks from the datasource"), + mappers=[ChunkMapper], + ), + ], tags={"order": TAGS_ORDER_HIGH}, ) @@ -239,6 +263,9 @@ async def map(self, question: str) -> HOContextBody: db=db_name, question=question, ) + await self.current_dag_context.save_to_share_data( + self._share_data_key, schema_info + ) context = self._prompt_template.format( db_name=db_name, table_info=schema_info, diff --git a/dbgpt/app/operators/rag.py b/dbgpt/app/operators/rag.py index 79d166ac0..d7fa75b24 100644 --- a/dbgpt/app/operators/rag.py +++ b/dbgpt/app/operators/rag.py @@ -1,6 +1,7 @@ from typing import List, Optional from dbgpt._private.config import Config +from dbgpt.core import Chunk from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import ( TAGS_ORDER_HIGH, @@ -93,6 +94,15 @@ def _load_space_name() -> List[OptionValue]: class HOKnowledgeOperator(MapOperator[str, HOContextBody]): + _share_data_key = "_higher_order_knowledge_operator_retriever_chunks" + + class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]): + async def map(self, context: HOContextBody) -> List[Chunk]: + chunks = await self.current_dag_context.get_from_share_data( + 
HOKnowledgeOperator._share_data_key
+            )
+            return chunks
+
     metadata = ViewMetadata(
         label=_("Knowledge Operator"),
         name="higher_order_knowledge_operator",
@@ -122,6 +132,14 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
         ],
         outputs=[
             _OUTPUTS_CONTEXT.new(),
+            IOField.build_from(
+                _("Chunks"),
+                "chunks",
+                Chunk,
+                is_list=True,
+                description=_("The retrieved chunks from the knowledge space"),
+                mappers=[ChunkMapper],
+            ),
         ],
         tags={"order": TAGS_ORDER_HIGH},
     )
@@ -185,6 +203,7 @@ async def map(self, query: str) -> HOContextBody:
         chunks = await self._space_retriever.aretrieve_with_scores(
             query, self._score_threshold
         )
+        await self.current_dag_context.save_to_share_data(self._share_data_key, chunks)
         return HOContextBody(
             context_key=self._context_key,
             context=[chunk.content for chunk in chunks],
diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py
index ffe6a7b0e..2f3521d24 100644
--- a/dbgpt/core/awel/dag/base.py
+++ b/dbgpt/core/awel/dag/base.py
@@ -619,6 +619,7 @@ def __init__(
         self._node_name_to_ids: Dict[str, str] = node_name_to_ids
         self._event_loop_task_id = event_loop_task_id
         self._dag_variables = dag_variables
+        self._share_data_lock = asyncio.Lock()
 
     @property
     def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -680,8 +681,9 @@ async def get_from_share_data(self, key: str) -> Any:
         Returns:
             Any: The share data, you can cast it to the real type
         """
-        logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
-        return self._share_data.get(key)
+        async with self._share_data_lock:
+            logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
+            return self._share_data.get(key)
 
     async def save_to_share_data(
         self, key: str, data: Any, overwrite: bool = False
@@ -694,10 +696,11 @@ async def save_to_share_data(
             overwrite (bool): Whether overwrite the share data if the key already
                 exists. Defaults to None.
         """
-        if key in self._share_data and not overwrite:
-            raise ValueError(f"Share data key {key} already exists")
-        logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
-        self._share_data[key] = data
+        async with self._share_data_lock:
+            if key in self._share_data and not overwrite:
+                raise ValueError(f"Share data key {key} already exists")
+            logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
+            self._share_data[key] = data
 
     async def get_task_share_data(self, task_name: str, key: str) -> Any:
         """Get share data by task name and key.
diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py
index db8bbcb84..4e691ed08 100644
--- a/dbgpt/core/awel/flow/base.py
+++ b/dbgpt/core/awel/flow/base.py
@@ -687,6 +687,10 @@ class IOField(Resource):
         " True",
         examples=[0, 1, 2],
     )
+    mappers: Optional[List[str]] = Field(
+        default=None,
+        description="Mappers that transform the field to the target type",
+    )
 
     @classmethod
     def build_from(
@@ -698,10 +702,16 @@ def build_from(
         is_list: bool = False,
         dynamic: bool = False,
         dynamic_minimum: int = 0,
+        mappers: Optional[Union[Type, List[Type]]] = None,
     ):
         """Build the resource from the type."""
         type_name = type.__qualname__
         type_cls = _get_type_name(type)
+        # TODO: Check the mapper instance can be created without required
+        #  parameters.
+
+        if mappers and not isinstance(mappers, list):
+            mappers = [mappers]
+        mappers_cls = [_get_type_name(m) for m in mappers] if mappers else None
         return cls(
             label=label,
             name=name,
@@ -711,6 +721,7 @@ def build_from(
             description=description or label,
             dynamic=dynamic,
             dynamic_minimum=dynamic_minimum,
+            mappers=mappers_cls,
         )
 
     @model_validator(mode="before")
diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py
index fe7a83f50..544adbf8d 100644
--- a/dbgpt/core/awel/flow/flow_factory.py
+++ b/dbgpt/core/awel/flow/flow_factory.py
@@ -1,10 +1,11 @@
 """Build AWEL DAGs from serialized data."""
 
+import dataclasses
 import logging
 import uuid
 from contextlib import suppress
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
 
 from typing_extensions import Annotated
 
@@ -565,6 +566,17 @@ def parse_variables(
         return [FlowVariables(**v) for v in variables]
 
 
+@dataclasses.dataclass
+class _KeyToNodeItem:
+    """Key to node item."""
+
+    key: str
+    source_order: int
+    target_order: int
+    mappers: List[str]
+    edge_index: int
+
+
 class FlowFactory:
     """Flow factory."""
 
@@ -580,8 +592,10 @@ def build(self, flow_panel: FlowPanel) -> DAG:
         key_to_operator_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource: Dict[str, ResourceMetadata] = {}
-        key_to_downstream: Dict[str, List[Tuple[str, int, int]]] = {}
-        key_to_upstream: Dict[str, List[Tuple[str, int, int]]] = {}
+        # Record current node's downstream
+        key_to_downstream: Dict[str, List[_KeyToNodeItem]] = {}
+        # Record current node's upstream
+        key_to_upstream: Dict[str, List[_KeyToNodeItem]] = {}
         key_to_upstream_node: Dict[str, List[FlowNodeData]] = {}
         for node in flow_data.nodes:
             key = node.id
@@ -595,7 +609,7 @@ def build(self, flow_panel: FlowPanel) -> DAG:
                 key_to_resource_nodes[key] = node
                 key_to_resource[key] = node.data
 
-        for edge in flow_data.edges:
+        for edge_index, edge in enumerate(flow_data.edges):
             source_key = edge.source
             target_key = edge.target
             source_node: FlowNodeData | None = key_to_operator_nodes.get(
@@ -615,12 +629,37 @@ def build(self, flow_panel: FlowPanel) -> DAG:
 
             if source_node.data.is_operator and target_node.data.is_operator:
                 # Operator to operator.
+                mappers = []
+                for i, out in enumerate(source_node.data.outputs):
+                    if i != edge.source_order:
+                        continue
+                    if out.mappers:
+                        # Current edge is a mapper edge, find the mappers.
+                        mappers = out.mappers
+                # Note: Mappers in the inputs of the target node are not supported yet.
+
                 downstream = key_to_downstream.get(source_key, [])
-                downstream.append((target_key, edge.source_order, edge.target_order))
+                downstream.append(
+                    _KeyToNodeItem(
+                        key=target_key,
+                        source_order=edge.source_order,
+                        target_order=edge.target_order,
+                        mappers=mappers,
+                        edge_index=edge_index,
+                    )
+                )
                 key_to_downstream[source_key] = downstream
 
                 upstream = key_to_upstream.get(target_key, [])
-                upstream.append((source_key, edge.source_order, edge.target_order))
+                upstream.append(
+                    _KeyToNodeItem(
+                        key=source_key,
+                        source_order=edge.source_order,
+                        target_order=edge.target_order,
+                        mappers=mappers,
+                        edge_index=edge_index,
+                    )
+                )
                 key_to_upstream[target_key] = upstream
             elif not source_node.data.is_operator and target_node.data.is_operator:
                 # Resource to operator.
@@ -678,10 +717,10 @@ def build(self, flow_panel: FlowPanel) -> DAG:
 
         # Sort the keys by the order of the nodes.
for key, value in key_to_downstream.items(): # Sort by source_order. - key_to_downstream[key] = sorted(value, key=lambda x: x[1]) + key_to_downstream[key] = sorted(value, key=lambda x: x.source_order) for key, value in key_to_upstream.items(): # Sort by target_order. - key_to_upstream[key] = sorted(value, key=lambda x: x[2]) + key_to_upstream[key] = sorted(value, key=lambda x: x.target_order) sorted_key_to_resource_nodes = list(key_to_resource_nodes.values()) sorted_key_to_resource_nodes = sorted( @@ -779,8 +818,8 @@ def build_dag( self, flow_panel: FlowPanel, key_to_tasks: Dict[str, DAGNode], - key_to_downstream: Dict[str, List[Tuple[str, int, int]]], - key_to_upstream: Dict[str, List[Tuple[str, int, int]]], + key_to_downstream: Dict[str, List[_KeyToNodeItem]], + key_to_upstream: Dict[str, List[_KeyToNodeItem]], dag_id: Optional[str] = None, ) -> DAG: """Build the DAG.""" @@ -827,7 +866,8 @@ def build_dag( # This upstream has been sorted according to the order in the downstream # So we just need to connect the task to the upstream. - for upstream_key, _, _ in upstream: + for up_item in upstream: + upstream_key = up_item.key # Just one direction. upstream_task = key_to_tasks.get(upstream_key) if not upstream_task: @@ -838,7 +878,13 @@ def build_dag( upstream_task.set_node_id(dag._new_node_id()) if upstream_task is None: raise ValueError("Unable to find upstream task.") - upstream_task >> task + tasks = _build_mapper_operators(dag, up_item.mappers) + tasks.append(task) + last_task = upstream_task + for t in tasks: + # Connect the task to the upstream task. + last_task >> t + last_task = t return dag def pre_load_requirements(self, flow_panel: FlowPanel): @@ -945,6 +991,23 @@ def _topological_sort( return key_to_order +def _build_mapper_operators(dag: DAG, mappers: List[str]) -> List[DAGNode]: + from .base import _get_type_cls + + tasks = [] + for mapper in mappers: + try: + mapper_cls = _get_type_cls(mapper) + task = mapper_cls() + if not task._node_id: + task.set_node_id(dag._new_node_id()) + tasks.append(task) + except Exception as e: + err_msg = f"Unable to build mapper task: {mapper}, error: {e}" + raise FlowMetadataException(err_msg) + return tasks + + def fill_flow_panel(flow_panel: FlowPanel): """Fill the flow panel with the latest metadata. 
@@ -973,6 +1036,7 @@ def fill_flow_panel(flow_panel: FlowPanel): i.dynamic = new_param.dynamic i.is_list = new_param.is_list i.dynamic_minimum = new_param.dynamic_minimum + i.mappers = new_param.mappers for i in node.data.outputs: if i.name in output_parameters: new_param = output_parameters[i.name] @@ -981,6 +1045,7 @@ def fill_flow_panel(flow_panel: FlowPanel): i.dynamic = new_param.dynamic i.is_list = new_param.is_list i.dynamic_minimum = new_param.dynamic_minimum + i.mappers = new_param.mappers else: data = cast(ResourceMetadata, node.data) key = data.get_origin_id() diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 6e17be15e..fd503f566 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -945,6 +945,16 @@ def __init__( class CommonLLMHttpTrigger(HttpTrigger): """Common LLM http trigger for AWEL.""" + class MessagesOutputMapper(MapOperator[CommonLLMHttpRequestBody, str]): + """Messages output mapper.""" + + async def map(self, request_body: CommonLLMHttpRequestBody) -> str: + """Map the request body to messages.""" + if isinstance(request_body.messages, str): + return request_body.messages + else: + raise ValueError("Messages to be transformed is not a string") + metadata = ViewMetadata( label=_("Common LLM Http Trigger"), name="common_llm_http_trigger", @@ -965,6 +975,16 @@ class CommonLLMHttpTrigger(HttpTrigger): "LLM http body" ), ), + IOField.build_from( + _("Request String Messages"), + "request_string_messages", + str, + description=_( + "The request string messages of the API endpoint, parsed from " + "'messages' field of the request body" + ), + mappers=[MessagesOutputMapper], + ), ], parameters=[ Parameter.build_from( diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 3aac0b24c..cc9c00341 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -388,7 +388,9 @@ def get_list_by_page( Returns: List[ServerResponse]: The response """ - page_result = self.dao.get_list_page(request, page, page_size) + page_result = self.dao.get_list_page( + request, page, page_size, desc_order_column=ServeEntity.gmt_modified.name + ) for item in page_result.items: metadata = self.dag_manager.get_dag_metadata( item.dag_id, alias_name=item.uid diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index db92ca09d..2db9607ea 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -852,7 +852,7 @@ class ExampleFlowUploadOperator(MapOperator[str, str]): ui=ui.UIUpload( max_file_size=1024 * 1024 * 100, up_event="button_click", - file_types=["image/*", "*.pdf"], + file_types=["image/*", ".pdf"], drag=True, attr=ui.UIUpload.UIAttribute(max_count=5), ), @@ -897,7 +897,7 @@ async def map(self, user_name: str) -> str: files_metadata = await self.blocking_func_to_async( self._parse_files_metadata, fsc ) - files_metadata_str = json.dumps(files_metadata, ensure_ascii=False) + files_metadata_str = json.dumps(files_metadata, ensure_ascii=False, indent=4) return "Your name is %s, and you files are %s." 
% ( user_name, files_metadata_str, From 494eb587dd89dc4a7347dc8c41e736c64dc1b2ff Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 29 Aug 2024 19:37:45 +0800 Subject: [PATCH 04/60] feat: Support variables query API --- dbgpt/core/interface/variables.py | 1 + dbgpt/model/cluster/client.py | 6 +- dbgpt/serve/flow/api/endpoints.py | 63 +++++ dbgpt/serve/flow/api/schemas.py | 13 +- dbgpt/serve/flow/api/variables_provider.py | 134 +++++++++++ dbgpt/serve/flow/service/service.py | 2 +- dbgpt/serve/flow/service/variables_service.py | 224 +++++++++++++++++- dbgpt/util/pagination_utils.py | 26 ++ dbgpt/util/tests/test_pagination_utils.py | 84 +++++++ 9 files changed, 544 insertions(+), 9 deletions(-) create mode 100644 dbgpt/util/tests/test_pagination_utils.py diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py index 22e035d52..5d538ad25 100644 --- a/dbgpt/core/interface/variables.py +++ b/dbgpt/core/interface/variables.py @@ -31,6 +31,7 @@ BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets" BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms" BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings" +# Not implemented yet BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers" BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources" BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents" diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py index d58645cb7..f4141cc23 100644 --- a/dbgpt/model/cluster/client.py +++ b/dbgpt/model/cluster/client.py @@ -27,7 +27,7 @@ name="auto_convert_message", type=bool, optional=True, - default=False, + default=True, description=_( "Whether to auto convert the messages that are not supported " "by the LLM to a compatible format" @@ -128,7 +128,7 @@ async def count_token(self, model: str, prompt: str) -> int: name="auto_convert_message", type=bool, optional=True, - default=False, + default=True, description=_( "Whether to auto convert the messages that are not supported " "by the LLM to a compatible format" @@ -158,7 +158,7 @@ class RemoteLLMClient(DefaultLLMClient): def __init__( self, controller_address: str = "http://127.0.0.1:8000", - auto_convert_message: bool = False, + auto_convert_message: bool = True, ): """Initialize the RemoteLLMClient.""" from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index c6148f994..3060956d5 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -21,6 +21,7 @@ RefreshNodeRequest, ServeRequest, ServerResponse, + VariablesKeyResponse, VariablesRequest, VariablesResponse, ) @@ -359,6 +360,62 @@ async def update_variables( return Result.succ(res) +@router.get( + "/variables", + response_model=Result[PaginationResult[VariablesResponse]], + dependencies=[Depends(check_api_key)], +) +async def get_variables_by_keys( + key: str = Query(..., description="variable key"), + scope: Optional[str] = Query(default=None, description="scope"), + scope_key: Optional[str] = Query(default=None, description="scope key"), + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + page: int = Query(default=1, description="current page"), + page_size: int = Query(default=20, description="page size"), +) -> Result[PaginationResult[VariablesResponse]]: + """Get the variables by keys + + Returns: + VariablesResponse: The response + """ + res = 
await get_variable_service().get_list_by_page( + key, + scope, + scope_key, + user_name, + sys_code, + page, + page_size, + ) + return Result.succ(res) + + +@router.get( + "/variables/keys", + response_model=Result[List[VariablesKeyResponse]], + dependencies=[Depends(check_api_key)], +) +async def get_variables_keys( + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + category: Optional[str] = Query(default=None, description="category"), +) -> Result[List[VariablesKeyResponse]]: + """Get the variable keys + + Returns: + VariablesKeyResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, + get_variable_service().list_keys, + user_name, + sys_code, + category, + ) + return Result.succ(res) + + @router.post("/flow/debug", dependencies=[Depends(check_api_key)]) async def debug_flow( flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service) @@ -477,10 +534,13 @@ async def import_flow( def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" from .variables_provider import ( + BuiltinAgentsVariablesProvider, BuiltinAllSecretVariablesProvider, BuiltinAllVariablesProvider, + BuiltinDatasourceVariablesProvider, BuiltinEmbeddingsVariablesProvider, BuiltinFlowVariablesProvider, + BuiltinKnowledgeSpacesVariablesProvider, BuiltinLLMVariablesProvider, BuiltinNodeVariablesProvider, ) @@ -494,4 +554,7 @@ def init_endpoints(system_app: SystemApp) -> None: system_app.register(BuiltinAllSecretVariablesProvider) system_app.register(BuiltinLLMVariablesProvider) system_app.register(BuiltinEmbeddingsVariablesProvider) + system_app.register(BuiltinDatasourceVariablesProvider) + system_app.register(BuiltinAgentsVariablesProvider) + system_app.register(BuiltinKnowledgeSpacesVariablesProvider) global_system_app = system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index cf82de982..6053dd885 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -2,7 +2,11 @@ from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core.awel import CommonLLMHttpRequestBody -from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest +from dbgpt.core.awel.flow.flow_factory import ( + FlowPanel, + VariablesRequest, + _VariablesRequestBase, +) from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest from ..config import SERVE_APP_NAME_HUMP @@ -28,6 +32,13 @@ class VariablesResponse(VariablesRequest): ) +class VariablesKeyResponse(_VariablesRequestBase): + """Variables Key response model. + + Just include the key, for select options in the frontend. 
+    """
+
+
 class RefreshNodeRequest(BaseModel):
     """Flow response model"""
 
diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py
index 4728f80e6..27ed63bf5 100644
--- a/dbgpt/serve/flow/api/variables_provider.py
+++ b/dbgpt/serve/flow/api/variables_provider.py
@@ -1,9 +1,12 @@
 from typing import List, Literal, Optional
 
 from dbgpt.core.interface.variables import (
+    BUILTIN_VARIABLES_CORE_AGENTS,
+    BUILTIN_VARIABLES_CORE_DATASOURCES,
     BUILTIN_VARIABLES_CORE_EMBEDDINGS,
     BUILTIN_VARIABLES_CORE_FLOW_NODES,
     BUILTIN_VARIABLES_CORE_FLOWS,
+    BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES,
     BUILTIN_VARIABLES_CORE_LLMS,
     BUILTIN_VARIABLES_CORE_SECRETS,
     BUILTIN_VARIABLES_CORE_VARIABLES,
@@ -54,6 +57,7 @@ def get_variables(
                 scope_key=scope_key,
                 sys_code=sys_code,
                 user_name=user_name,
+                description=flow.description,
             )
         )
         return variables
@@ -91,6 +95,7 @@ def get_variables(
                 scope_key=scope_key,
                 sys_code=sys_code,
                 user_name=user_name,
+                description=metadata.get("description"),
             )
         )
         return variables
@@ -122,10 +127,14 @@ def _get_variables_from_db(
                 name=var.name,
                 label=var.label,
                 value=var.value,
+                category=var.category,
+                value_type=var.value_type,
                 scope=scope,
                 scope_key=scope_key,
                 sys_code=sys_code,
                 user_name=user_name,
+                enabled=1 if var.enabled else 0,
+                description=var.description,
             )
         )
         return variables
@@ -258,3 +267,128 @@ async def async_get_variables(
         return await self._get_models(
             key, scope, scope_key, sys_code, user_name, "text2vec"
         )
+
+
+class BuiltinDatasourceVariablesProvider(BuiltinVariablesProvider):
+    """Builtin datasource variables provider.
+
+    Provide all datasource variables by variables "${dbgpt.core.datasources}"
+    """
+
+    name = BUILTIN_VARIABLES_CORE_DATASOURCES
+
+    def get_variables(
+        self,
+        key: str,
+        scope: str = "global",
+        scope_key: Optional[str] = None,
+        sys_code: Optional[str] = None,
+        user_name: Optional[str] = None,
+    ) -> List[StorageVariables]:
+        """Get the builtin variables."""
+        from dbgpt.serve.datasource.service.service import (
+            DatasourceServeResponse,
+            Service,
+        )
+
+        all_datasource: List[DatasourceServeResponse] = Service.get_instance(
+            self.system_app
+        ).list()
+
+        variables = []
+        for datasource in all_datasource:
+            label = f"[{datasource.db_type}]{datasource.db_name}"
+            variables.append(
+                StorageVariables(
+                    key=key,
+                    name=datasource.db_name,
+                    label=label,
+                    value=datasource.db_name,
+                    scope=scope,
+                    scope_key=scope_key,
+                    sys_code=sys_code,
+                    user_name=user_name,
+                    description=datasource.comment,
+                )
+            )
+        return variables
+
+
+class BuiltinAgentsVariablesProvider(BuiltinVariablesProvider):
+    """Builtin agents variables provider.
+ + Provide all agents variables by variables "${dbgpt.core.agent.agents}" + """ + + name = BUILTIN_VARIABLES_CORE_AGENTS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.agent.core.agent_manage import get_agent_manager + + agent_manager = get_agent_manager(self.system_app) + agents = agent_manager.list_agents() + variables = [] + for agent in agents: + variables.append( + StorageVariables( + key=key, + name=agent["name"], + label=agent["desc"], + value=agent["name"], + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + description=agent["desc"], + ) + ) + return variables + + +class BuiltinKnowledgeSpacesVariablesProvider(BuiltinVariablesProvider): + """Builtin knowledge variables provider. + + Provide all knowledge variables by variables "${dbgpt.core.knowledge_spaces}" + """ + + name = BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.serve.rag.service.service import Service, SpaceServeRequest + + # TODO: Query with user_name and sys_code + knowledge_list = Service.get_instance(self.system_app).get_list( + SpaceServeRequest() + ) + variables = [] + for k in knowledge_list: + variables.append( + StorageVariables( + key=key, + name=k.name, + label=k.name, + value=k.name, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + description=k.desc, + ) + ) + return variables diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index cc9c00341..30a8a06c0 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -230,7 +230,7 @@ def load_dag_from_dbgpts(self, is_first_load: bool = False): continue # Set state to DEPLOYED flow.state = State.DEPLOYED - exist_inst = self.get({"name": flow.name}) + exist_inst = self.dao.get_one({"name": flow.name}) if not exist_inst: self.create_and_save_dag(flow, save_failed_flow=True) elif is_first_load or exist_inst.state != State.RUNNING: diff --git a/dbgpt/serve/flow/service/variables_service.py b/dbgpt/serve/flow/service/variables_service.py index 09e2a16b0..fbb4cc9b9 100644 --- a/dbgpt/serve/flow/service/variables_service.py +++ b/dbgpt/serve/flow/service/variables_service.py @@ -1,10 +1,25 @@ from typing import List, Optional from dbgpt import SystemApp -from dbgpt.core.interface.variables import StorageVariables, VariablesProvider -from dbgpt.serve.core import BaseService +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_AGENTS, + BUILTIN_VARIABLES_CORE_DATASOURCES, + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_RERANKERS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, + StorageVariables, + VariablesProvider, +) +from dbgpt.serve.core import BaseService, blocking_func_to_async +from dbgpt.util import PaginationResult +from dbgpt.util.i18n_utils import _ -from ..api.schemas import VariablesRequest, VariablesResponse +from ..api.schemas import VariablesKeyResponse, VariablesRequest, 
VariablesResponse from ..config import ( SERVE_CONFIG_KEY_PREFIX, SERVE_VARIABLES_SERVICE_COMPONENT_NAME, @@ -12,6 +27,93 @@ ) from ..models.models import VariablesDao, VariablesEntity +BUILTIN_VARIABLES = [ + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_FLOWS, + label=_("All AWEL Flows"), + description=_("Fetch all AWEL flows in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_FLOW_NODES, + label=_("All AWEL Flow Nodes"), + description=_("Fetch all AWEL flow nodes in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_VARIABLES, + label=_("All Variables"), + description=_("Fetch all variables in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_SECRETS, + label=_("All Secrets"), + description=_("Fetch all secrets in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_LLMS, + label=_("All LLMs"), + description=_("Fetch all LLMs in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_EMBEDDINGS, + label=_("All Embeddings"), + description=_("Fetch all embeddings models in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_RERANKERS, + label=_("All Rerankers"), + description=_("Fetch all rerankers in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_DATASOURCES, + label=_("All Data Sources"), + description=_("Fetch all data sources in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_AGENTS, + label=_("All Agents"), + description=_("Fetch all agents in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES, + label=_("All Knowledge Spaces"), + description=_("Fetch all knowledge spaces in the system"), + value_type="str", + category="common", + scope="global", + ), +] + + +def _is_builtin_variable(key: str) -> bool: + return key in [v.key for v in BUILTIN_VARIABLES] + class VariablesService( BaseService[VariablesEntity, VariablesRequest, VariablesResponse] @@ -148,5 +250,119 @@ def update(self, _: int, request: VariablesRequest) -> VariablesResponse: return self.dao.get_one(query) def list_all_variables(self, category: str = "common") -> List[VariablesResponse]: - """List all variables.""" + """List all variables. + + Please note that this method will return all variables in the system, it may + be a large list. 
+ """ return self.dao.get_list({"enabled": True, "category": category}) + + def list_keys( + self, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + category: Optional[str] = None, + ) -> List[VariablesKeyResponse]: + """List all keys.""" + results = [] + + # TODO: More high performance way to get the keys + all_db_variables = self.dao.get_list( + { + "enabled": True, + "category": category, + "user_name": user_name, + "sys_code": sys_code, + } + ) + if not user_name: + # Only return the keys that are not user specific + all_db_variables = [v for v in all_db_variables if not v.user_name] + if not sys_code: + # Only return the keys that are not system specific + all_db_variables = [v for v in all_db_variables if not v.sys_code] + key_to_db_variable = {} + for db_variable in all_db_variables: + key = db_variable.key + if key not in key_to_db_variable: + key_to_db_variable[key] = db_variable + + # Append all builtin variables to the results + results.extend(BUILTIN_VARIABLES) + + # Append all db variables to the results + for key, db_variable in key_to_db_variable.items(): + results.append( + VariablesKeyResponse( + key=key, + label=db_variable.label, + description=db_variable.description, + value_type=db_variable.value_type, + category=db_variable.category, + scope=db_variable.scope, + scope_key=db_variable.scope_key, + ) + ) + return results + + async def get_list_by_page( + self, + key: str, + scope: Optional[str] = None, + scope_key: Optional[str] = None, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + page: int = 1, + page_size: int = 20, + ) -> PaginationResult[VariablesResponse]: + """Get a list of variables by page.""" + if not _is_builtin_variable(key): + query = { + "key": key, + "scope": scope, + "scope_key": scope_key, + "user_name": user_name, + "sys_code": sys_code, + } + return await blocking_func_to_async( + self._system_app, + self.dao.get_list_page, + query, + page, + page_size, + desc_order_column="gmt_modified", + ) + else: + variables: List[ + StorageVariables + ] = await self.variables_provider.async_get_variables( + key=key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + result_variables = [] + for entity in variables: + result_variables.append( + VariablesResponse( + id=-1, + key=entity.key, + name=entity.name, + label=entity.label, + value=entity.value, + value_type=entity.value_type, + category=entity.category, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=True if entity.enabled == 1 else False, + user_name=entity.user_name, + sys_code=entity.sys_code, + description=entity.description, + ) + ) + return PaginationResult.build_from_all( + result_variables, + page, + page_size, + ) diff --git a/dbgpt/util/pagination_utils.py b/dbgpt/util/pagination_utils.py index f8c20ccd9..5b67333c6 100644 --- a/dbgpt/util/pagination_utils.py +++ b/dbgpt/util/pagination_utils.py @@ -15,3 +15,29 @@ class PaginationResult(BaseModel, Generic[T]): total_pages: int = Field(..., description="total number of pages") page: int = Field(..., description="Current page number") page_size: int = Field(..., description="Number of items per page") + + @classmethod + def build_from_all( + cls, all_items: List[T], page: int, page_size: int + ) -> "PaginationResult[T]": + """Build a pagination result from all items""" + if page < 1: + page = 1 + if page_size < 1: + page_size = 1 + total_count = len(all_items) + total_pages = ( + (total_count + page_size - 1) // page_size if total_count > 0 else 0 + ) + 
page = max(1, min(page, total_pages)) if total_pages > 0 else 0 + start_index = (page - 1) * page_size if page > 0 else 0 + end_index = min(start_index + page_size, total_count) + items = all_items[start_index:end_index] + + return cls( + items=items, + total_count=total_count, + total_pages=total_pages, + page=page, + page_size=page_size, + ) diff --git a/dbgpt/util/tests/test_pagination_utils.py b/dbgpt/util/tests/test_pagination_utils.py new file mode 100644 index 000000000..d0d2132c5 --- /dev/null +++ b/dbgpt/util/tests/test_pagination_utils.py @@ -0,0 +1,84 @@ +from dbgpt.util.pagination_utils import PaginationResult + + +def test_build_from_all_normal_case(): + items = list(range(100)) + result = PaginationResult.build_from_all(items, page=2, page_size=20) + + assert len(result.items) == 20 + assert result.items == list(range(20, 40)) + assert result.total_count == 100 + assert result.total_pages == 5 + assert result.page == 2 + assert result.page_size == 20 + + +def test_build_from_all_empty_list(): + items = [] + result = PaginationResult.build_from_all(items, page=1, page_size=5) + + assert result.items == [] + assert result.total_count == 0 + assert result.total_pages == 0 + assert result.page == 0 + assert result.page_size == 5 + + +def test_build_from_all_last_page(): + items = list(range(95)) + result = PaginationResult.build_from_all(items, page=5, page_size=20) + + assert len(result.items) == 15 + assert result.items == list(range(80, 95)) + assert result.total_count == 95 + assert result.total_pages == 5 + assert result.page == 5 + assert result.page_size == 20 + + +def test_build_from_all_page_out_of_range(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=10, page_size=10) + + assert len(result.items) == 10 + assert result.items == list(range(40, 50)) + assert result.total_count == 50 + assert result.total_pages == 5 + assert result.page == 5 + assert result.page_size == 10 + + +def test_build_from_all_page_zero(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=0, page_size=10) + + assert len(result.items) == 10 + assert result.items == list(range(0, 10)) + assert result.total_count == 50 + assert result.total_pages == 5 + assert result.page == 1 + assert result.page_size == 10 + + +def test_build_from_all_negative_page(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=-1, page_size=10) + + assert len(result.items) == 10 + assert result.items == list(range(0, 10)) + assert result.total_count == 50 + assert result.total_pages == 5 + assert result.page == 1 + assert result.page_size == 10 + + +def test_build_from_all_page_size_larger_than_total(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=1, page_size=100) + + assert len(result.items) == 50 + assert result.items == list(range(50)) + assert result.total_count == 50 + assert result.total_pages == 1 + assert result.page == 1 + assert result.page_size == 100 From 147051cefdb21bef000af04d9769ebaa4b12b7ca Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 29 Aug 2024 23:07:51 +0800 Subject: [PATCH 05/60] feat: Support query file metadatas --- dbgpt/serve/file/api/endpoints.py | 77 ++++++++++++++++++++++++++++- dbgpt/serve/file/api/schemas.py | 48 +++++++++++++++++- dbgpt/serve/file/service/service.py | 39 ++++++++++++++- 3 files changed, 159 insertions(+), 5 deletions(-) diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py index 26bbb9673..d5b65bc54 100644 
--- a/dbgpt/serve/file/api/endpoints.py +++ b/dbgpt/serve/file/api/endpoints.py @@ -1,3 +1,4 @@ +import asyncio import logging from functools import cache from typing import List, Optional @@ -13,7 +14,13 @@ from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from .schemas import ServeRequest, ServerResponse, UploadFileResponse +from .schemas import ( + FileMetadataBatchRequest, + FileMetadataResponse, + ServeRequest, + ServerResponse, + UploadFileResponse, +) router = APIRouter() logger = logging.getLogger(__name__) @@ -162,6 +169,74 @@ async def delete_file( return Result.succ(None) +@router.get( + "/files/metadata", + response_model=Result[FileMetadataResponse], + dependencies=[Depends(check_api_key)], +) +async def get_file_metadata( + uri: Optional[str] = Query(None, description="File URI"), + bucket: Optional[str] = Query(None, description="Bucket name"), + file_id: Optional[str] = Query(None, description="File ID"), + service: Service = Depends(get_service), +) -> Result[FileMetadataResponse]: + """Get file metadata by URI or by bucket and file_id.""" + if not uri and not (bucket and file_id): + raise HTTPException( + status_code=400, + detail="Either uri or (bucket and file_id) must be provided", + ) + + metadata = await blocking_func_to_async( + global_system_app, service.get_file_metadata, uri, bucket, file_id + ) + return Result.succ(metadata) + + +@router.post( + "/files/metadata/batch", + response_model=Result[List[FileMetadataResponse]], + dependencies=[Depends(check_api_key)], +) +async def get_files_metadata_batch( + request: FileMetadataBatchRequest, service: Service = Depends(get_service) +) -> Result[List[FileMetadataResponse]]: + """Get metadata for multiple files by URIs or bucket and file_id pairs.""" + if not request.uris and not request.bucket_file_pairs: + raise HTTPException( + status_code=400, + detail="Either uris or bucket_file_pairs must be provided", + ) + + batch_req = [] + if request.uris: + for uri in request.uris: + batch_req.append((uri, None, None)) + elif request.bucket_file_pairs: + for pair in request.bucket_file_pairs: + batch_req.append((None, pair.bucket, pair.file_id)) + else: + raise HTTPException( + status_code=400, + detail="Either uris or bucket_file_pairs must be provided", + ) + + batch_req_tasks = [ + blocking_func_to_async( + global_system_app, service.get_file_metadata, uri, bucket, file_id + ) + for uri, bucket, file_id in batch_req + ] + + metadata_list = await asyncio.gather(*batch_req_tasks) + if not metadata_list: + raise HTTPException( + status_code=404, + detail="File metadata not found", + ) + return Result.succ(metadata_list) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" global global_system_app diff --git a/dbgpt/serve/file/api/schemas.py b/dbgpt/serve/file/api/schemas.py index 911f71db3..bd8b3bbf2 100644 --- a/dbgpt/serve/file/api/schemas.py +++ b/dbgpt/serve/file/api/schemas.py @@ -1,7 +1,13 @@ # Define your Pydantic schemas here -from typing import Any, Dict +from typing import Any, Dict, List, Optional -from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict +from dbgpt._private.pydantic import ( + BaseModel, + ConfigDict, + Field, + model_to_dict, + model_validator, +) from ..config import SERVE_APP_NAME_HUMP @@ -41,3 +47,41 @@ class UploadFileResponse(BaseModel): def to_dict(self, **kwargs) -> Dict[str, Any]: """Convert the model to a dictionary""" return model_to_dict(self, **kwargs) 
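+# A rough sketch of how a client might exercise the batch metadata endpoint
+# added in endpoints.py above; the host, mount path and API key here are
+# illustrative assumptions, not values taken from this patch:
+#
+#   import requests
+#
+#   body = {"bucket_file_pairs": [{"bucket": "my-bucket", "file_id": "f-123"}]}
+#   resp = requests.post(
+#       "http://localhost:5670/api/v2/serve/file/files/metadata/batch",
+#       headers={"Authorization": "Bearer YOUR_API_KEY"},
+#       json=body,
+#   )
+#   metadata_list = resp.json()["data"]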
+ + +class _BucketFilePair(BaseModel): + """Bucket file pair model""" + + bucket: str = Field(..., title="The bucket of the file") + file_id: str = Field(..., title="The ID of the file") + + +class FileMetadataBatchRequest(BaseModel): + """File metadata batch request model""" + + uris: Optional[List[str]] = Field(None, title="The URIs of the files") + bucket_file_pairs: Optional[List[_BucketFilePair]] = Field( + None, title="The bucket file pairs" + ) + + @model_validator(mode="after") + def check_uris_or_bucket_file_pairs(self): + # Check if either uris or bucket_file_pairs is provided + if not (self.uris or self.bucket_file_pairs): + raise ValueError("Either uris or bucket_file_pairs must be provided") + # Check only one of uris or bucket_file_pairs is provided + if self.uris and self.bucket_file_pairs: + raise ValueError("Only one of uris or bucket_file_pairs can be provided") + return self + + +class FileMetadataResponse(BaseModel): + """File metadata model""" + + file_name: str = Field(..., title="The name of the file") + file_id: str = Field(..., title="The ID of the file") + bucket: str = Field(..., title="The bucket of the file") + uri: str = Field(..., title="The URI of the file") + file_size: int = Field(..., title="The size of the file") + user_name: Optional[str] = Field(None, title="The user name") + sys_code: Optional[str] = Field(None, title="The system code") diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py index 13e8b6225..85940ed35 100644 --- a/dbgpt/serve/file/service/service.py +++ b/dbgpt/serve/file/service/service.py @@ -1,7 +1,7 @@ import logging from typing import BinaryIO, List, Optional, Tuple -from fastapi import UploadFile +from fastapi import HTTPException, UploadFile from dbgpt.component import BaseComponent, SystemApp from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI @@ -10,7 +10,12 @@ from dbgpt.util.pagination_utils import PaginationResult from dbgpt.util.tracer import root_tracer, trace -from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse +from ..api.schemas import ( + FileMetadataResponse, + ServeRequest, + ServerResponse, + UploadFileResponse, +) from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..models.models import ServeDao, ServeEntity @@ -117,3 +122,33 @@ def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetada def delete_file(self, bucket: str, file_id: str) -> None: """Delete a file by file_id.""" self.file_storage_client.delete_file_by_id(bucket, file_id) + + def get_file_metadata( + self, + uri: Optional[str] = None, + bucket: Optional[str] = None, + file_id: Optional[str] = None, + ) -> Optional[FileMetadataResponse]: + """Get the metadata of a file by file_id.""" + if uri: + parsed_uri = FileStorageURI.parse(uri) + bucket, file_id = parsed_uri.bucket, parsed_uri.file_id + if not (bucket and file_id): + raise ValueError("Either uri or bucket and file_id must be provided.") + metadata = self.file_storage_client.storage_system.get_file_metadata( + bucket, file_id + ) + if not metadata: + raise HTTPException( + status_code=404, + detail=f"File metadata not found: bucket={bucket}, file_id={file_id}, uri={uri}", + ) + return FileMetadataResponse( + file_name=metadata.file_name, + file_id=metadata.file_id, + bucket=metadata.bucket, + uri=metadata.uri, + file_size=metadata.file_size, + user_name=metadata.user_name, + sys_code=metadata.sys_code, + ) From 
08749a01104b98fdb3dfee956fc2e55c47fa5b48 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 30 Aug 2024 07:24:22 +0800 Subject: [PATCH 06/60] feat: Support dynamic parameters --- dbgpt/core/awel/flow/base.py | 107 +++++++++++----- examples/awel/awel_flow_ui_components.py | 155 ++++++++++++++++++++++- 2 files changed, 228 insertions(+), 34 deletions(-) diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 4e691ed08..99aa77c8b 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -36,6 +36,8 @@ } _BASIC_TYPES = [str, int, float, bool, dict, list, set] +_DYNAMIC_PARAMETER_TYPES = [str, int, float, bool] +DefaultParameterType = Union[str, int, float, bool, None] T = TypeVar("T", bound="ViewMixin") TM = TypeVar("TM", bound="TypeMetadata") @@ -292,9 +294,6 @@ def get_category(cls, value: Type[Any]) -> "ParameterCategory": return cls.RESOURCER -DefaultParameterType = Union[str, int, float, bool, None] - - class TypeMetadata(BaseModel): """The metadata of the type.""" @@ -313,7 +312,23 @@ def new(self: TM) -> TM: return self.__class__(**self.model_dump(exclude_defaults=True)) -class Parameter(TypeMetadata, Serializable): +class BaseDynamic(BaseModel): + """The base dynamic field.""" + + dynamic: bool = Field( + default=False, + description="Whether current field is dynamic", + examples=[True, False], + ) + dynamic_minimum: int = Field( + default=0, + description="The minimum count of the dynamic field, only valid when dynamic is" + " True", + examples=[0, 1, 2], + ) + + +class Parameter(BaseDynamic, TypeMetadata, Serializable): """Parameter for build operator.""" label: str = Field( @@ -332,11 +347,6 @@ class Parameter(TypeMetadata, Serializable): description="The category of the parameter", examples=["common", "resource"], ) - # resource_category: Optional[str] = Field( - # default=None, - # description="The category of the resource, just for resource type", - # examples=["llm_client", "common"], - # ) resource_type: ResourceType = Field( default=ResourceType.INSTANCE, description="The type of the resource, just for resource type", @@ -389,6 +399,17 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values[k] = handled_v return values + @model_validator(mode="after") + def check_parameters(self) -> "Parameter": + """Check the parameters.""" + if self.dynamic and not self.is_list: + raise FlowMetadataException("Dynamic parameter must be a list.") + if self.dynamic and self.dynamic_minimum < 0: + raise FlowMetadataException( + "Dynamic minimum must be greater than or equal to 0."
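+ # In practice a dynamic parameter is declared with is_list=True,
+ # dynamic=True and a non-negative dynamic_minimum, e.g. (illustrative):
+ #   Parameter.build_from("Values", "values", str, is_list=True,
+ #                        dynamic=True, dynamic_minimum=1)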
+ ) + return self + @classmethod def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any: def _parse_single_value(vv: Any) -> Any: @@ -450,6 +471,8 @@ def build_from( description: Optional[str] = None, options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None, resource_type: ResourceType = ResourceType.INSTANCE, + dynamic: bool = False, + dynamic_minimum: int = 0, alias: Optional[List[str]] = None, ui: Optional[UIComponent] = None, ): @@ -461,6 +484,8 @@ def build_from( raise ValueError(f"Default value is missing for optional parameter {name}.") if not optional: default = None + if dynamic and type not in _DYNAMIC_PARAMETER_TYPES: + raise ValueError("Dynamic parameter must be str, int, float or bool.") return cls( label=label, name=name, @@ -474,6 +499,8 @@ def build_from( placeholder=placeholder, description=description or label, options=options, + dynamic=dynamic, + dynamic_minimum=dynamic_minimum, alias=alias, ui=ui, ) @@ -635,6 +662,11 @@ class BaseResource(Serializable, BaseModel): description="The label to display in UI", examples=["LLM Operator", "OpenAI LLM Client"], ) + custom_label: Optional[str] = Field( + None, + description="The custom label to display in UI", + examples=["LLM Operator", "OpenAI LLM Client"], + ) name: str = Field( ..., description="The name of the operator", @@ -668,7 +700,7 @@ class IOFiledType(str, Enum): LIST = "list" -class IOField(Resource): +class IOField(BaseDynamic, Resource): """The input or output field of the operator.""" is_list: bool = Field( @@ -676,17 +708,6 @@ class IOField(Resource): description="Whether current field is list", examples=[True, False], ) - dynamic: bool = Field( - default=False, - description="Whether current field is dynamic", - examples=[True, False], - ) - dynamic_minimum: int = Field( - default=0, - description="The minimum count of the dynamic field, only valid when dynamic is" - " True", - examples=[0, 1, 2], - ) mappers: Optional[List[str]] = Field( default=None, description="The mappers of the field, transform the field to the target type", @@ -724,18 +745,6 @@ def build_from( mappers=mappers_cls, ) - @model_validator(mode="before") - @classmethod - def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Pre fill the metadata.""" - if not isinstance(values, dict): - return values - if "dynamic" not in values: - values["dynamic"] = False - if "dynamic_minimum" not in values: - values["dynamic_minimum"] = 0 - return values - class BaseMetadata(BaseResource): """The base metadata.""" @@ -1137,6 +1146,38 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["outputs"] = new_outputs return values + @model_validator(mode="after") + def check_metadata(self) -> "ViewMetadata": + """Check the metadata.""" + if self.inputs: + for field in self.inputs: + if field.mappers: + raise ValueError("Input field can't have mappers.") + dyn_cnt, is_last_field_dynamic = 0, False + for field in self.inputs: + if field.dynamic: + dyn_cnt += 1 + is_last_field_dynamic = True + else: + if is_last_field_dynamic: + raise ValueError("Dynamic field input must be the last field.") + is_last_field_dynamic = False + if dyn_cnt > 1: + raise ValueError("Only one dynamic input field is allowed.") + if self.outputs: + dyn_cnt, is_last_field_dynamic = 0, False + for field in self.outputs: + if field.dynamic: + dyn_cnt += 1 + is_last_field_dynamic = True + else: + if is_last_field_dynamic: + raise ValueError("Dynamic field output must be the last field.") + is_last_field_dynamic = False + if 
dyn_cnt > 1: + raise ValueError("Only one dynamic output field is allowed.") + return self + def get_operator_key(self) -> str: """Get the operator key.""" if not self.flow_type: diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 2db9607ea..ce411a79d 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, List, Optional -from dbgpt.core.awel import MapOperator +from dbgpt.core.awel import JoinOperator, MapOperator from dbgpt.core.awel.flow import ( FunctionDynamicOptions, IOField, @@ -1243,3 +1243,156 @@ def execute_code_blocks(self, code_blocks): if exitcode != 0: return exitcode, logs_all return exitcode, logs_all + + +class ExampleFlowDynamicParametersOperator(MapOperator[str, str]): + """An example flow operator that includes dynamic parameters.""" + + metadata = ViewMetadata( + label="Example Dynamic Parameters Operator", + name="example_dynamic_parameters_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes dynamic parameters.", + parameters=[ + Parameter.build_from( + "Dynamic String", + "dynamic_1", + type=str, + is_list=True, + placeholder="Please input the dynamic parameter", + description="The dynamic parameter you want to use, you can add more, " + "at least 1 parameter.", + dynamic=True, + dynamic_minimum=1, + ui=ui.UIInput(), + ), + Parameter.build_from( + "Dynamic Integer", + "dynamic_2", + type=int, + is_list=True, + placeholder="Please input the dynamic parameter", + description="The dynamic parameter you want to use, you can add more, " + "at least 0 parameter.", + dynamic=True, + dynamic_minimum=0, + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Dynamic", + "dynamic", + str, + description="User's selected dynamic.", + ), + ], + ) + + def __init__(self, dynamic_1: List[str], dynamic_2: List[int], **kwargs): + super().__init__(**kwargs) + if not dynamic_1: + raise ValueError("The dynamic string is empty.") + self.dynamic_1 = dynamic_1 + self.dynamic_2 = dynamic_2 + + async def map(self, user_name: str) -> str: + """Map the user name to the dynamic.""" + return "Your name is %s, and your dynamic is %s." % ( + user_name, + f"dynamic_1: {self.dynamic_1}, dynamic_2: {self.dynamic_2}", + ) + + +class ExampleFlowDynamicOutputsOperator(MapOperator[str, str]): + """An example flow operator that includes dynamic outputs.""" + + metadata = ViewMetadata( + label="Example Dynamic Outputs Operator", + name="example_dynamic_outputs_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes dynamic outputs.", + parameters=[], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Dynamic", + "dynamic", + str, + description="User's selected dynamic.", + dynamic=True, + dynamic_minimum=1, + ), + ], + ) + + async def map(self, user_name: str) -> str: + """Map the user name to the dynamic.""" + return "Your name is %s, this operator has dynamic outputs." 
% user_name + + +class ExampleFlowDynamicInputsOperator(JoinOperator[str]): + """An example flow operator that includes dynamic inputs.""" + + metadata = ViewMetadata( + label="Example Dynamic Inputs Operator", + name="example_dynamic_inputs_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes dynamic inputs.", + parameters=[], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + IOField.build_from( + "Other Inputs", + "other_inputs", + str, + description="Other inputs.", + dynamic=True, + dynamic_minimum=0, + ), + ], + outputs=[ + IOField.build_from( + "Dynamic", + "dynamic", + str, + description="User's selected dynamic.", + ), + ], + ) + + def __init__(self, **kwargs): + super().__init__(combine_function=self.join, **kwargs) + + async def join(self, user_name: str, *other_inputs: str) -> str: + """Map the user name to the dynamic.""" + if not other_inputs: + dyn_inputs = ["You have no other inputs."] + else: + dyn_inputs = [ + f"Input {i}: {input_data}" for i, input_data in enumerate(other_inputs) + ] + dyn_str = "\n".join(dyn_inputs) + return "Your name is %s, and your dynamic is %s." % ( + user_name, + f"other_inputs:\n{dyn_str}", + ) From f1e00a7502d0d44b099e34bc82c3dfa66fa366aa Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 30 Aug 2024 11:27:02 +0800 Subject: [PATCH 07/60] feat: Support rag flow template --- dbgpt/serve/dbgpts/__init__.py | 0 dbgpt/serve/flow/api/endpoints.py | 29 +- dbgpt/serve/flow/service/service.py | 59 + .../en/rag-chat-awel-flow-template.json | 1088 +++++++++++++++++ 4 files changed, 1174 insertions(+), 2 deletions(-) create mode 100644 dbgpt/serve/dbgpts/__init__.py create mode 100644 dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json diff --git a/dbgpt/serve/dbgpts/__init__.py b/dbgpt/serve/dbgpts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 3060956d5..56f73f34e 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -14,7 +14,7 @@ from dbgpt.util import PaginationResult from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig -from ..service.service import Service +from ..service.service import Service, _parse_flow_template_from_json from ..service.variables_service import VariablesService from .schemas import ( FlowDebugRequest, @@ -512,7 +512,7 @@ async def import_flow( raise HTTPException( status_code=400, detail="invalid json file, missing 'flow' key" ) - flow = ServeRequest.parse_obj(json_dict["flow"]) + flow = _parse_flow_template_from_json(json_dict["flow"]) elif file_extension == "zip": from ..service.share_utils import _parse_flow_from_zip_file @@ -531,6 +531,31 @@ async def import_flow( return Result.succ(flow) +@router.get( + "/flow/templates", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def query_flow_templates( + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + page: int = Query(default=1, description="current page"), + page_size: int = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query Flow templates.""" + + res = await blocking_func_to_async( + global_system_app, + service.get_flow_templates, + 
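+ # get_flow_templates is synchronous (it scans the local template
+ # directory), so it is offloaded to the default executor via
+ # blocking_func_to_async to keep the event loop responsive.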
user_name, + sys_code, + page, + page_size, + ) + return Result.succ(res) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" from .variables_provider import ( diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 30a8a06c0..54d07d49a 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -1,5 +1,6 @@ import json import logging +import os from typing import AsyncIterator, List, Optional, cast import schedule @@ -399,6 +400,47 @@ def get_list_by_page( item.metadata = metadata.to_dict() return page_result + def get_flow_templates( + self, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + page: int = 1, + page_size: int = 20, + ) -> PaginationResult[ServerResponse]: + """Get a list of Flow templates + + Args: + user_name (Optional[str]): The user name + sys_code (Optional[str]): The system code + page (int): The page number + page_size (int): The page size + Returns: + List[ServerResponse]: The response + """ + local_file_templates = self._get_flow_templates_from_files() + return PaginationResult.build_from_all(local_file_templates, page, page_size) + + def _get_flow_templates_from_files(self) -> List[ServerResponse]: + """Get a list of Flow templates from files""" + user_lang = self._system_app.config.get_current_lang(default="en") + # List files in current directory + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + template_dir = os.path.join(parent_dir, "templates", user_lang) + default_template_dir = os.path.join(parent_dir, "templates", "en") + if not os.path.exists(template_dir): + template_dir = default_template_dir + templates = [] + for root, _, files in os.walk(template_dir): + for file in files: + if file.endswith(".json"): + try: + with open(os.path.join(root, file), "r") as f: + data = json.load(f) + templates.append(_parse_flow_template_from_json(data)) + except Exception as e: + logger.warning(f"Load template {file} error: {str(e)}") + return templates + async def chat_stream_flow_str( self, flow_uid: str, request: CommonLLMHttpRequestBody ) -> AsyncIterator[str]: @@ -638,3 +680,20 @@ async def _wrapper_chat_stream_flow_str( break else: yield f"data:{text}\n\n" + + +def _parse_flow_template_from_json(json_dict: dict) -> ServerResponse: + """Parse the flow from json + + Args: + json_dict (dict): The json dict + + Returns: + ServerResponse: The flow + """ + flow_json = json_dict["flow"] + flow_json["editable"] = False + del flow_json["uid"] + flow_json["state"] = State.INITIALIZING + flow_json["dag_id"] = None + return ServerResponse(**flow_json) diff --git a/dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json b/dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json new file mode 100644 index 000000000..60ff5c911 --- /dev/null +++ b/dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json @@ -0,0 +1,1088 @@ +{ + "flow": { + "uid": "21eb87d5-b63a-4f41-b2aa-28d01033344d", + "label": "RAG Chat AWEL flow template", + "name": "rag_chat_awel_flow_template", + "flow_category": "chat_flow", + "description": "An example of a RAG chat AWEL flow.", + "state": "running", + "error_message": "", + "source": "DBGPT-WEB", + "source_url": null, + "version": "0.1.1", + "define_type": "json", + "editable": true, + "user_name": null, + "sys_code": null, + "dag_id": "flow_dag_rag_chat_awel_flow_template_21eb87d5-b63a-4f41-b2aa-28d01033344d", + "gmt_created": "2024-08-30 10:48:56", + "gmt_modified": "2024-08-30 
10:48:56", + "metadata": { + "sse_output": true, + "streaming_output": true, + "tags": {}, + "triggers": [ + { + "trigger_type": "http", + "path": "/api/v1/awel/trigger/templates/flow_dag_rag_chat_awel_flow_template_21eb87d5-b63a-4f41-b2aa-28d01033344d", + "methods": [ + "POST" + ], + "trigger_mode": "chat" + } + ] + }, + "variables": null, + "authors": null, + "flow_data": { + "edges": [ + { + "source": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + "source_order": 0, + "target": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "target_order": 0, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_handle": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|outputs|0", + "target_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|inputs|0", + "type": "buttonedge" + }, + { + "source": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + "source_order": 1, + "target": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "target_order": 0, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "source_handle": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|outputs|1", + "target_handle": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0|inputs|0", + "type": "buttonedge" + }, + { + "source": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "source_order": 0, + "target": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "target_order": 1, + "id": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0|operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_handle": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0|outputs|0", + "target_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|inputs|1", + "type": "buttonedge" + }, + { + "source": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "source_order": 0, + "target": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "target_order": 0, + "id": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0|operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_handle": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0|outputs|0", + "target_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|parameters|0", + "type": "buttonedge" + }, + { + "source": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_order": 0, + "target": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "target_order": 0, + "id": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "source_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|outputs|0", + "target_handle": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0|inputs|0", + "type": "buttonedge" + } + ], + "viewport": { + "x": 900.5986504747431, + "y": 420.90015979869725, + "zoom": 0.6903331247004052 + }, + "nodes": [ + { + "width": 320, + "height": 632, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + 
"position": { + "x": -1164.0000230376968, + "y": -501.9869760888273, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": -1164.0000230376968, + "y": -501.9869760888273, + "zoom": 0.0 + }, + "data": { + "label": "Common LLM Http Trigger", + "custom_label": null, + "name": "common_llm_http_trigger", + "description": "Trigger your workflow by http request, and parse the request body as a common LLM http body", + "category": "trigger", + "category_label": "Trigger", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "input", + "inputs": [], + "outputs": [ + { + "type_name": "CommonLLMHttpRequestBody", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpRequestBody", + "label": "Request Body", + "custom_label": null, + "name": "request_body", + "description": "The request body of the API endpoint, parse as a common LLM http body", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "label": "Request String Messages", + "custom_label": null, + "name": "request_string_messages", + "description": "The request string messages of the API endpoint, parsed from 'messages' field of the request body", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": [ + "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpTrigger.MessagesOutputMapper" + ] + } + ], + "version": "v1", + "type_name": "CommonLLMHttpTrigger", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpTrigger", + "parameters": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "API Endpoint", + "name": "endpoint", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "/example/{dag_id}", + "placeholder": null, + "description": "The API endpoint", + "value": "/templates/{dag_id}", + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Http Methods", + "name": "methods", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "POST", + "placeholder": null, + "description": "The methods of the API endpoint", + "value": null, + "options": [ + { + "label": "HTTP Method PUT", + "name": "http_put", + "value": "PUT", + "children": null + }, + { + "label": "HTTP Method POST", + "name": "http_post", + "value": "POST", + "children": null + } + ] + }, + { + "type_name": "bool", + "type_cls": "builtins.bool", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Streaming Response", + "name": "streaming_response", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": false, + "placeholder": null, + "description": "Whether the response is streaming", + "value": false, + "options": null + }, + { + "type_name": "BaseHttpBody", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.BaseHttpBody", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Http Response Body", + "name": "http_response_body", + "is_list": false, + "category": "resource", + "resource_type": "class", + "optional": true, + "default": null, + "placeholder": null, + "description": "The response body of the API endpoint", + "value": null, + "options": null 
+ }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Response Media Type", + "name": "response_media_type", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The response media type", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Http Status Code", + "name": "status_code", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 200, + "placeholder": null, + "description": "The http status code", + "value": null, + "options": null + } + ] + } + }, + { + "width": 320, + "height": 910, + "id": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "position": { + "x": 661.094354143159, + "y": -368.93541722528227, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": 661.094354143159, + "y": -368.93541722528227, + "zoom": 0.0 + }, + "data": { + "label": "Streaming LLM Operator", + "custom_label": null, + "name": "higher_order_streaming_llm_operator", + "description": "High-level streaming LLM operator, supports multi-round conversation (conversation window, token length and no multi-round).", + "category": "llm", + "category_label": "LLM", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "map", + "inputs": [ + { + "type_name": "CommonLLMHttpRequestBody", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpRequestBody", + "label": "Common LLM Request Body", + "custom_label": null, + "name": "common_llm_request_body", + "description": "The common LLM request body.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + }, + { + "type_name": "HOContextBody", + "type_cls": "dbgpt.app.operators.llm.HOContextBody", + "label": "Extra Context", + "custom_label": null, + "name": "extra_context", + "description": "Extra context for building prompt(Knowledge context, database schema, etc), you can add multiple context.", + "dynamic": true, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + } + ], + "outputs": [ + { + "type_name": "ModelOutput", + "type_cls": "dbgpt.core.interface.llm.ModelOutput", + "label": "Streaming Model Output", + "custom_label": null, + "name": "streaming_model_output", + "description": "The streaming model output.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": null + } + ], + "version": "v1", + "type_name": "HOStreamingLLMOperator", + "type_cls": "dbgpt.app.operators.llm.HOStreamingLLMOperator", + "parameters": [ + { + "type_name": "ChatPromptTemplate", + "type_cls": "dbgpt.core.interface.prompt.ChatPromptTemplate", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Prompt Template", + "name": "prompt_template", + "is_list": false, + "category": "resource", + "resource_type": "instance", + "optional": false, + "default": null, + "placeholder": null, + "description": "The prompt template for the conversation.", + "value": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Model 
Name", + "name": "model", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The model name.", + "value": null, + "options": null + }, + { + "type_name": "LLMClient", + "type_cls": "dbgpt.core.interface.llm.LLMClient", + "dynamic": false, + "dynamic_minimum": 0, + "label": "LLM Client", + "name": "llm_client", + "is_list": false, + "category": "resource", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The LLM Client, how to connect to the LLM model, if not provided, it will use the default client deployed by DB-GPT.", + "value": null, + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "History Message Merge Mode", + "name": "history_merge_mode", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "none", + "placeholder": null, + "description": "The history merge mode, supports 'none', 'window' and 'token'. 'none': no history merge, 'window': merge by conversation window, 'token': merge by token length.", + "value": "window", + "options": [ + { + "label": "No History", + "name": "none", + "value": "none", + "children": null + }, + { + "label": "Message Window", + "name": "window", + "value": "window", + "children": null + }, + { + "label": "Token Length", + "name": "token", + "value": "token", + "children": null + } + ], + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "select", + "size": null, + "attr": null + } + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "User Message Key", + "name": "user_message_key", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "user_input", + "placeholder": null, + "description": "The key of the user message in your prompt, default is 'user_input'.", + "value": null, + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "History Key", + "name": "history_key", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The chat history key, with chat history message pass to prompt template, if not provided, it will parse the prompt template to get the key.", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Keep Start Rounds", + "name": "keep_start_rounds", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The start rounds to keep in the chat history.", + "value": 0, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Keep End Rounds", + "name": "keep_end_rounds", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The end rounds to keep in the chat history.", + "value": 10, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Max Token Limit", + "name": "max_token_limit", + 
"is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 2048, + "placeholder": null, + "description": "The max token limit to keep in the chat history.", + "value": null, + "options": null + } + ] + } + }, + { + "width": 320, + "height": 774, + "id": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "position": { + "x": -781.3390803520426, + "y": 112.87665693387501, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": -781.3390803520426, + "y": 112.87665693387501, + "zoom": 0.0 + }, + "data": { + "label": "Knowledge Operator", + "custom_label": null, + "name": "higher_order_knowledge_operator", + "description": "Knowledge Operator, retrieve your knowledge(documents) from knowledge space", + "category": "rag", + "category_label": "RAG", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "map", + "inputs": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "label": "User question", + "custom_label": null, + "name": "query", + "description": "The user question to retrieve the knowledge", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + } + ], + "outputs": [ + { + "type_name": "HOContextBody", + "type_cls": "dbgpt.app.operators.llm.HOContextBody", + "label": "Retrieved context", + "custom_label": null, + "name": "context", + "description": "The retrieved context from the knowledge space", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + }, + { + "type_name": "Chunk", + "type_cls": "dbgpt.core.interface.knowledge.Chunk", + "label": "Chunks", + "custom_label": null, + "name": "chunks", + "description": "The retrieved chunks from the knowledge space", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": [ + "dbgpt.app.operators.rag.HOKnowledgeOperator.ChunkMapper" + ] + } + ], + "version": "v1", + "type_name": "HOKnowledgeOperator", + "type_cls": "dbgpt.app.operators.rag.HOKnowledgeOperator", + "parameters": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Knowledge Space Name", + "name": "knowledge_space", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": false, + "default": null, + "placeholder": null, + "description": "The name of the knowledge space", + "value": "k_cmd2", + "options": [ + { + "label": "k_cmd2", + "name": "k_cmd2", + "value": "k_cmd2", + "children": null + }, + { + "label": "f5", + "name": "f5", + "value": "f5", + "children": null + }, + { + "label": "f4", + "name": "f4", + "value": "f4", + "children": null + }, + { + "label": "t333", + "name": "t333", + "value": "t333", + "children": null + }, + { + "label": "f3", + "name": "f3", + "value": "f3", + "children": null + }, + { + "label": "f1", + "name": "f1", + "value": "f1", + "children": null + }, + { + "label": "sdf", + "name": "sdf", + "value": "sdf", + "children": null + }, + { + "label": "sfsd", + "name": "sfsd", + "value": "sfsd", + "children": null + }, + { + "label": "hello", + "name": "hello", + "value": "hello", + "children": null + }, + { + "label": "k1", + "name": "k1", + "value": "k1", + "children": null + }, + { + "label": "f2", + "name": "f2", + "value": "f2", + "children": null + }, + { + "label": "test_f1", + "name": "test_f1", + "value": 
"test_f1", + "children": null + }, + { + "label": "SMMF", + "name": "SMMF", + "value": "SMMF", + "children": null + }, + { + "label": "docker_xxx", + "name": "docker_xxx", + "value": "docker_xxx", + "children": null + }, + { + "label": "t2", + "name": "t2", + "value": "t2", + "children": null + }, + { + "label": "t1", + "name": "t1", + "value": "t1", + "children": null + }, + { + "label": "test_graph", + "name": "test_graph", + "value": "test_graph", + "children": null + }, + { + "label": "small", + "name": "small", + "value": "small", + "children": null + }, + { + "label": "ttt", + "name": "ttt", + "value": "ttt", + "children": null + }, + { + "label": "bf", + "name": "bf", + "value": "bf", + "children": null + }, + { + "label": "new_big_file", + "name": "new_big_file", + "value": "new_big_file", + "children": null + }, + { + "label": "test_big_fild", + "name": "test_big_fild", + "value": "test_big_fild", + "children": null + }, + { + "label": "Greenplum", + "name": "Greenplum", + "value": "Greenplum", + "children": null + }, + { + "label": "Mytest", + "name": "Mytest", + "value": "Mytest", + "children": null + }, + { + "label": "dba", + "name": "dba", + "value": "dba", + "children": null + } + ] + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Context Key", + "name": "context", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "context", + "placeholder": null, + "description": "The key of the context, it will be used in building the prompt", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Top K", + "name": "top_k", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 5, + "placeholder": null, + "description": "The number of chunks to retrieve", + "value": null, + "options": null + }, + { + "type_name": "float", + "type_cls": "builtins.float", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Minimum Match Score", + "name": "score_threshold", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 0.3, + "placeholder": null, + "description": "The minimum match score for the retrieved chunks, it will be dropped if the match score is less than the threshold", + "value": null, + "options": null, + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "slider", + "size": null, + "attr": { + "disabled": false, + "min": 0.0, + "max": 1.0, + "step": 0.1 + }, + "show_input": false + } + }, + { + "type_name": "bool", + "type_cls": "builtins.bool", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Reranker Enabled", + "name": "reranker_enabled", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "Whether to enable the reranker", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Reranker Top K", + "name": "reranker_top_k", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 3, + "placeholder": null, + "description": "The top k for the reranker", + "value": null, + "options": null + } + ] + } + }, + { + "width": 320, + "height": 884, + "id": 
"resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "position": { + "x": 195.5602050169747, + "y": 175.41495969060128, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": 195.5602050169747, + "y": 175.41495969060128, + "zoom": 0.0 + }, + "data": { + "type_name": "CommonChatPromptTemplate", + "type_cls": "dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate", + "label": "Common Chat Prompt Template", + "custom_label": null, + "name": "common_chat_prompt_template", + "description": "The operator to build the prompt with static prompt.", + "category": "prompt", + "category_label": "Prompt", + "flow_type": "resource", + "icon": null, + "documentation_url": null, + "id": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0", + "ui_size": "large" + }, + "resource_type": "instance", + "parent_cls": [ + "dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate", + "dbgpt.core.interface.prompt.ChatPromptTemplate", + "dbgpt.core.interface.prompt.BasePromptTemplate", + "pydantic.main.BaseModel" + ], + "parameters": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "System Message", + "name": "system_message", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "You are a helpful AI Assistant.", + "placeholder": null, + "description": "The system message.", + "value": "You are a helpful AI assistant.\nBased on the known information below, provide users with professional and concise answers to their questions.\nconstraints:\n 1.Ensure to include original markdown formatting elements such as images, links, tables, or code blocks without alteration in the response if they are present in the provided information.\n For example, image format should be ![image.png](xxx), link format [xxx](xxx), table format should be represented with |xxx|xxx|xxx|, and code format with xxx.\n 2.If the information available in the knowledge base is insufficient to answer the question, state clearly: \"The content provided in the knowledge base is not enough to answer this question,\" and avoid making up answers.\n 3.When responding, it is best to summarize the points in the order of 1, 2, 3, And displayed in markdwon format.\n\nknown information: \n{context}\n\nuser question:\n{user_input}\n\nwhen answering, use the same language as the \"user\".", + "options": null, + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "text_area", + "size": "large", + "attr": { + "disabled": false, + "status": null, + "prefix": null, + "suffix": null, + "show_count": null, + "max_length": null, + "auto_size": { + "min_rows": 2, + "max_rows": 20 + } + }, + "editor": { + "width": 800, + "height": 400 + } + } + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Message placeholder", + "name": "message_placeholder", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "chat_history", + "placeholder": null, + "description": "The chat history message placeholder.", + "value": null, + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Human Message", + "name": "human_message", + "is_list": false, + "category": "common", + 
"resource_type": "instance", + "optional": true, + "default": "{user_input}", + "placeholder": "{user_input}", + "description": "The human message.", + "value": null, + "options": null, + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "text_area", + "size": "large", + "attr": { + "disabled": false, + "status": null, + "prefix": null, + "suffix": null, + "show_count": null, + "max_length": null, + "auto_size": { + "min_rows": 2, + "max_rows": 20 + } + }, + "editor": { + "width": 800, + "height": 400 + } + } + } + ] + } + }, + { + "width": 320, + "height": 235, + "id": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "position": { + "x": 1087.8490700167088, + "y": 389.9348086323575, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": 1087.8490700167088, + "y": 389.9348086323575, + "zoom": 0.0 + }, + "data": { + "label": "OpenAI Streaming Output Operator", + "custom_label": null, + "name": "openai_streaming_output_operator", + "description": "The OpenAI streaming LLM operator.", + "category": "output_parser", + "category_label": "Output Parser", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "transform_stream", + "inputs": [ + { + "type_name": "ModelOutput", + "type_cls": "dbgpt.core.interface.llm.ModelOutput", + "label": "Upstream Model Output", + "custom_label": null, + "name": "model_output", + "description": "The model output of upstream.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": null + } + ], + "outputs": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "label": "Model Output", + "custom_label": null, + "name": "model_output", + "description": "The model output after transformed to openai stream format.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": null + } + ], + "version": "v1", + "type_name": "OpenAIStreamingOutputOperator", + "type_cls": "dbgpt.model.utils.chatgpt_utils.OpenAIStreamingOutputOperator", + "parameters": [] + } + } + ] + } + } +} \ No newline at end of file From c67b50052d50f6f302902a4f7a67725303f52c3a Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 30 Aug 2024 15:00:14 +0800 Subject: [PATCH 08/60] chore: Merge latest code --- dbgpt/app/component_configs.py | 12 + dbgpt/app/operators/__init__.py | 4 + dbgpt/app/operators/converter.py | 186 ++++++++ dbgpt/app/operators/datasource.py | 336 +++++++++++++ dbgpt/app/operators/llm.py | 443 ++++++++++++++++++ dbgpt/app/operators/rag.py | 191 ++++++++ dbgpt/core/awel/flow/__init__.py | 2 + dbgpt/core/awel/flow/base.py | 116 ++++- dbgpt/core/awel/flow/flow_factory.py | 42 +- dbgpt/core/awel/flow/ui.py | 39 +- dbgpt/core/awel/trigger/http_trigger.py | 3 + dbgpt/core/interface/llm.py | 3 + dbgpt/core/interface/message.py | 27 +- .../core/interface/operators/llm_operator.py | 33 +- .../interface/operators/prompt_operator.py | 31 +- dbgpt/core/interface/output_parser.py | 89 +++- dbgpt/core/interface/prompt.py | 12 + dbgpt/model/cluster/client.py | 4 +- dbgpt/model/operators/llm_operator.py | 27 +- dbgpt/model/utils/chatgpt_utils.py | 9 +- dbgpt/rag/summary/db_summary_client.py | 3 +- dbgpt/serve/agent/resource/datasource.py | 67 ++- dbgpt/serve/agent/resource/knowledge.py | 1 + dbgpt/serve/flow/api/endpoints.py | 21 +- dbgpt/serve/flow/service/service.py | 8 +- 
dbgpt/serve/rag/operators/knowledge_space.py | 2 +- 26 files changed, 1643 insertions(+), 68 deletions(-) create mode 100644 dbgpt/app/operators/__init__.py create mode 100644 dbgpt/app/operators/converter.py create mode 100644 dbgpt/app/operators/datasource.py create mode 100644 dbgpt/app/operators/llm.py create mode 100644 dbgpt/app/operators/rag.py diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index a8a0f24d1..418d9eae1 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -60,6 +60,7 @@ def initialize_components( _initialize_openapi(system_app) # Register serve apps register_serve_apps(system_app, CFG, param.port) + _initialize_operators() def _initialize_model_cache(system_app: SystemApp, port: int): @@ -128,3 +129,14 @@ def _initialize_openapi(system_app: SystemApp): from dbgpt.app.openapi.api_v1.editor.service import EditorService system_app.register(EditorService) + + +def _initialize_operators(): + from dbgpt.app.operators.converter import StringToInteger + from dbgpt.app.operators.datasource import ( + HODatasourceExecutorOperator, + HODatasourceRetrieverOperator, + ) + from dbgpt.app.operators.llm import HOLLMOperator, HOStreamingLLMOperator + from dbgpt.app.operators.rag import HOKnowledgeOperator + from dbgpt.serve.agent.resource.datasource import DatasourceResource diff --git a/dbgpt/app/operators/__init__.py b/dbgpt/app/operators/__init__.py new file mode 100644 index 000000000..353336a34 --- /dev/null +++ b/dbgpt/app/operators/__init__.py @@ -0,0 +1,4 @@ +"""Operators package. + +This package contains all higher-order operators that are used to build workflows. +""" diff --git a/dbgpt/app/operators/converter.py b/dbgpt/app/operators/converter.py new file mode 100644 index 000000000..1115e0de4 --- /dev/null +++ b/dbgpt/app/operators/converter.py @@ -0,0 +1,186 @@ +"""Type Converter Operators.""" + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + Parameter, + ViewMetadata, +) +from dbgpt.util.i18n_utils import _ + +_INPUTS_STRING = IOField.build_from( + _("String"), + "string", + str, + description=_("The string to be converted to other types."), +) +_INPUTS_INTEGER = IOField.build_from( + _("Integer"), + "integer", + int, + description=_("The integer to be converted to other types."), +) +_INPUTS_FLOAT = IOField.build_from( + _("Float"), + "float", + float, + description=_("The float to be converted to other types."), +) +_INPUTS_BOOLEAN = IOField.build_from( + _("Boolean"), + "boolean", + bool, + description=_("The boolean to be converted to other types."), +) + +_OUTPUTS_STRING = IOField.build_from( + _("String"), + "string", + str, + description=_("The string converted from other types."), +) +_OUTPUTS_INTEGER = IOField.build_from( + _("Integer"), + "integer", + int, + description=_("The integer converted from other types."), +) +_OUTPUTS_FLOAT = IOField.build_from( + _("Float"), + "float", + float, + description=_("The float converted from other types."), +) +_OUTPUTS_BOOLEAN = IOField.build_from( + _("Boolean"), + "boolean", + bool, + description=_("The boolean converted from other types."), +) + + +class StringToInteger(MapOperator[str, int]): + """Converts a string to an integer.""" + + metadata = ViewMetadata( + label=_("String to Integer"), + name="default_converter_string_to_integer", + description=_("Converts a string to an integer."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_STRING], + 
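+ # The conversion itself is just the map_function passed to __init__
+ # below; int(x) raises ValueError on non-numeric input, which surfaces
+ # as a task failure when the flow runs.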
outputs=[_OUTPUTS_INTEGER], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new StringToInteger operator.""" + super().__init__(map_function=lambda x: int(x), **kwargs) + + +class StringToFloat(MapOperator[str, float]): + """Converts a string to a float.""" + + metadata = ViewMetadata( + label=_("String to Float"), + name="default_converter_string_to_float", + description=_("Converts a string to a float."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_STRING], + outputs=[_OUTPUTS_FLOAT], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new StringToFloat operator.""" + super().__init__(map_function=lambda x: float(x), **kwargs) + + +class StringToBoolean(MapOperator[str, bool]): + """Converts a string to a boolean.""" + + metadata = ViewMetadata( + label=_("String to Boolean"), + name="default_converter_string_to_boolean", + description=_("Converts a string to a boolean, true: 'true', '1', 'y'"), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[ + Parameter.build_from( + _("True Values"), + "true_values", + str, + optional=True, + default="true,1,y", + description=_("Comma-separated values that should be treated as True."), + ) + ], + inputs=[_INPUTS_STRING], + outputs=[_OUTPUTS_BOOLEAN], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, true_values: str = "true,1,y", **kwargs): + """Create a new StringToBoolean operator.""" + true_values_list = true_values.split(",") + true_values_list = [x.strip().lower() for x in true_values_list] + super().__init__(map_function=lambda x: x.lower() in true_values_list, **kwargs) + + +class IntegerToString(MapOperator[int, str]): + """Converts an integer to a string.""" + + metadata = ViewMetadata( + label=_("Integer to String"), + name="default_converter_integer_to_string", + description=_("Converts an integer to a string."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_INTEGER], + outputs=[_OUTPUTS_STRING], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new IntegerToString operator.""" + super().__init__(map_function=lambda x: str(x), **kwargs) + + +class FloatToString(MapOperator[float, str]): + """Converts a float to a string.""" + + metadata = ViewMetadata( + label=_("Float to String"), + name="default_converter_float_to_string", + description=_("Converts a float to a string."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_FLOAT], + outputs=[_OUTPUTS_STRING], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new FloatToString operator.""" + super().__init__(map_function=lambda x: str(x), **kwargs) + + +class BooleanToString(MapOperator[bool, str]): + """Converts a boolean to a string.""" + + metadata = ViewMetadata( + label=_("Boolean to String"), + name="default_converter_boolean_to_string", + description=_("Converts a boolean to a string."), + category=OperatorCategory.TYPE_CONVERTER, + parameters=[], + inputs=[_INPUTS_BOOLEAN], + outputs=[_OUTPUTS_STRING], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, **kwargs): + """Create a new BooleanToString operator.""" + super().__init__(map_function=lambda x: str(x), **kwargs) diff --git a/dbgpt/app/operators/datasource.py b/dbgpt/app/operators/datasource.py new file mode 100644 index 000000000..7fe16feaa --- /dev/null +++ b/dbgpt/app/operators/datasource.py @@ -0,0 +1,336 @@ +import json +import logging +from typing 
diff --git a/dbgpt/app/operators/datasource.py b/dbgpt/app/operators/datasource.py
new file mode 100644
index 000000000..7fe16feaa
--- /dev/null
+++ b/dbgpt/app/operators/datasource.py
@@ -0,0 +1,336 @@
+import json
+import logging
+from typing import List, Optional
+
+from dbgpt._private.config import Config
+from dbgpt.agent.resource.database import DBResource
+from dbgpt.core.awel import DAGContext, MapOperator
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    Parameter,
+    ViewMetadata,
+    ui,
+)
+from dbgpt.core.operators import BaseLLM
+from dbgpt.util.i18n_utils import _
+from dbgpt.vis.tags.vis_chart import default_chart_type_prompt
+
+from .llm import HOContextBody
+
+logger = logging.getLogger(__name__)
+
+CFG = Config()
+
+_DEFAULT_CHART_TYPE = default_chart_type_prompt()
+
+_DEFAULT_TEMPLATE_EN = """You are a database expert.
+Please answer the user's question based on the database selected by the user and some \
+of the available table structure definitions of the database.
+Database name:
+     {db_name}
+Table structure definition:
+     {table_info}
+
+Constraint:
+    1.Please understand the user's intention based on the user's question, and use the \
+    given table structure definition to create a grammatically correct {dialect} sql. \
+    If sql is not required, answer the user's question directly.
+    2.Always limit the query to a maximum of {max_num_results} results unless the user \
+    specifies in the question the specific number of rows of data he wishes to obtain.
+    3.You can only use the tables provided in the table structure information to \
+    generate sql. If you cannot generate sql based on the provided table structure, \
+    please say: "The table structure information provided is not enough to generate \
+    sql queries." It is prohibited to fabricate information at will.
+    4.Please be careful not to mistake the relationship between tables and columns \
+    when generating SQL.
+    5.Please check the correctness of the SQL and ensure that the query performance is \
+    optimized under correct conditions.
+    6.Please choose the best one from the display methods given below for data \
+    rendering, and put the type name into the name parameter value that returns the \
+    required format. If you cannot find the most suitable one, use 'Table' as the \
+    display method. The available data display methods are as follows: {display_type}
+
+User Question:
+    {user_input}
+Please think step by step and respond according to the following JSON format:
+    {response}
+Ensure the response is correct json and can be parsed by Python json.loads.
+"""
+
+_DEFAULT_TEMPLATE_ZH = """你是一个数据库专家.
+请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
+数据库名:
+     {db_name}
+表结构定义:
+     {table_info}
+
+约束:
+    1. 请根据用户问题理解用户意图,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要 \
+    sql,则直接回答用户问题。
+    2. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {max_num_results} \
+    个结果。
+    3. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:\
+    “提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
+    4. 请注意生成SQL时不要弄错表和列的关系
+    5. 请检查SQL的正确性,并保证正确的情况下优化查询性能
+    6.请从如下给出的展示方式中选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值中\
+    ,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {display_type}
+用户问题:
+    {user_input}
+请一步步思考并按照以下JSON格式回复:
+    {response}
+确保返回正确的json并且可以被Python json.loads方法解析.
+""" +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) + +_DEFAULT_RESPONSE = json.dumps( + { + "thoughts": "thoughts summary to say to user", + "sql": "SQL Query to run", + "display_type": "Data display method", + }, + ensure_ascii=False, + indent=4, +) + +_PARAMETER_DATASOURCE = Parameter.build_from( + _("Datasource"), + "datasource", + type=DBResource, + description=_("The datasource to retrieve the context"), +) +_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from( + _("Prompt Template"), + "prompt_template", + type=str, + optional=True, + default=_DEFAULT_TEMPLATE, + description=_("The prompt template to build a database prompt"), + ui=ui.DefaultUITextArea(), +) +_PARAMETER_DISPLAY_TYPE = Parameter.build_from( + _("Display Type"), + "display_type", + type=str, + optional=True, + default=_DEFAULT_CHART_TYPE, + description=_("The display type for the data"), + ui=ui.DefaultUITextArea(), +) +_PARAMETER_MAX_NUM_RESULTS = Parameter.build_from( + _("Max Number of Results"), + "max_num_results", + type=int, + optional=True, + default=50, + description=_("The maximum number of results to return"), +) +_PARAMETER_RESPONSE_FORMAT = Parameter.build_from( + _("Response Format"), + "response_format", + type=str, + optional=True, + default=_DEFAULT_RESPONSE, + description=_("The response format, default is a JSON format"), + ui=ui.DefaultUITextArea(), +) + +_PARAMETER_CONTEXT_KEY = Parameter.build_from( + _("Context Key"), + "context_key", + type=str, + optional=True, + default="context", + description=_("The key of the context, it will be used in building the prompt"), +) +_INPUTS_QUESTION = IOField.build_from( + _("User question"), + "query", + str, + description=_("The user question to retrieve table schemas from the datasource"), +) +_OUTPUTS_CONTEXT = IOField.build_from( + _("Retrieved context"), + "context", + HOContextBody, + description=_("The retrieved context from the datasource"), +) + +_INPUTS_SQL_DICT = IOField.build_from( + _("SQL dict"), + "sql_dict", + dict, + description=_("The SQL to be executed wrapped in a dictionary, generated by LLM"), +) +_OUTPUTS_SQL_RESULT = IOField.build_from( + _("SQL result"), + "sql_result", + str, + description=_("The result of the SQL execution"), +) + +_INPUTS_SQL_DICT_LIST = IOField.build_from( + _("SQL dict list"), + "sql_dict_list", + dict, + description=_( + "The SQL list to be executed wrapped in a dictionary, generated by LLM" + ), + is_list=True, +) + + +class GPTVisMixin: + async def save_view_message(self, dag_ctx: DAGContext, view: str): + """Save the view message.""" + await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view) + + +class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]): + """Retrieve the table schemas from the datasource.""" + + metadata = ViewMetadata( + label=_("Datasource Retriever Operator"), + name="higher_order_datasource_retriever_operator", + description=_("Retrieve the table schemas from the datasource."), + category=OperatorCategory.DATABASE, + parameters=[ + _PARAMETER_DATASOURCE.new(), + _PARAMETER_PROMPT_TEMPLATE.new(), + _PARAMETER_DISPLAY_TYPE.new(), + _PARAMETER_MAX_NUM_RESULTS.new(), + _PARAMETER_RESPONSE_FORMAT.new(), + _PARAMETER_CONTEXT_KEY.new(), + ], + inputs=[_INPUTS_QUESTION.new()], + outputs=[_OUTPUTS_CONTEXT.new()], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__( + self, + datasource: DBResource, + prompt_template: str = _DEFAULT_TEMPLATE, + display_type: str = _DEFAULT_CHART_TYPE, + max_num_results: int = 
50,
+        response_format: str = _DEFAULT_RESPONSE,
+        context_key: Optional[str] = "context",
+        **kwargs,
+    ):
+        """Initialize the operator."""
+        super().__init__(**kwargs)
+        self._datasource = datasource
+        self._prompt_template = prompt_template
+        self._display_type = display_type
+        self._max_num_results = max_num_results
+        self._response_format = response_format
+        self._context_key = context_key
+
+    async def map(self, question: str) -> HOContextBody:
+        """Retrieve the context from the datasource."""
+        db_name = self._datasource._db_name
+        dialect = self._datasource.dialect
+        schema_info = await self.blocking_func_to_async(
+            self._datasource.get_schema_link,
+            db=db_name,
+            question=question,
+        )
+        context = self._prompt_template.format(
+            db_name=db_name,
+            table_info=schema_info,
+            dialect=dialect,
+            max_num_results=self._max_num_results,
+            display_type=self._display_type,
+            user_input=question,
+            response=self._response_format,
+        )
+
+        return HOContextBody(
+            context_key=self._context_key,
+            context=context,
+        )
+
+
+class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
+    """Execute the SQL generated by the LLM against the datasource."""
+
+    metadata = ViewMetadata(
+        label=_("Datasource Executor Operator"),
+        name="higher_order_datasource_executor_operator",
+        description=_("Execute the SQL generated by the LLM against the datasource."),
+        category=OperatorCategory.DATABASE,
+        parameters=[_PARAMETER_DATASOURCE.new()],
+        inputs=[_INPUTS_SQL_DICT.new()],
+        outputs=[_OUTPUTS_SQL_RESULT.new()],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, datasource: DBResource, **kwargs):
+        """Initialize the operator."""
+        MapOperator.__init__(self, **kwargs)
+        self._datasource = datasource
+
+    async def map(self, sql_dict: dict) -> str:
+        """Execute the SQL against the datasource and render the result."""
+        from dbgpt.vis.tags.vis_chart import VisChart
+
+        if not isinstance(sql_dict, dict):
+            raise ValueError(
+                "The input value of datasource executor should be a dictionary."
+            )
+        vis = VisChart()
+        sql = sql_dict.get("sql")
+        if not sql:
+            return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
+        data_df = await self._datasource.query_to_df(sql)
+        view = await vis.display(chart=sql_dict, data_df=data_df)
+        await self.save_view_message(self.current_dag_context, view)
+        return view
+
+
+class HODatasourceDashboardOperator(GPTVisMixin, MapOperator[dict, str]):
+    """Execute multiple SQL statements and render them as a dashboard."""
+
+    metadata = ViewMetadata(
+        label=_("Datasource Dashboard Operator"),
+        name="higher_order_datasource_dashboard_operator",
+        description=_("Execute multiple SQL statements and render them as a dashboard."),
+        category=OperatorCategory.DATABASE,
+        parameters=[_PARAMETER_DATASOURCE.new()],
+        inputs=[_INPUTS_SQL_DICT_LIST.new()],
+        outputs=[_OUTPUTS_SQL_RESULT.new()],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, datasource: DBResource, **kwargs):
+        """Initialize the operator."""
+        MapOperator.__init__(self, **kwargs)
+        self._datasource = datasource
+
+    async def map(self, sql_dict_list: List[dict]) -> str:
+        """Execute each SQL statement and build the dashboard view."""
+        from dbgpt.vis.tags.vis_dashboard import VisDashboard
+
+        if not isinstance(sql_dict_list, list):
+            raise ValueError(
+                "The input value of datasource executor should be a list of dictionaries."
+            )
+        vis = VisDashboard()
+        chart_params = []
+        for chart_item in sql_dict_list:
+            chart_dict = {k: v for k, v in chart_item.items()}
+            sql = chart_item.get("sql")
+            try:
+                data_df = await self._datasource.query_to_df(sql)
+                chart_dict["data"] = data_df
+            except Exception as e:
+                logger.warning(f"SQL execution failed: {str(e)}")
+                chart_dict["err_msg"] = str(e)
+            chart_params.append(chart_dict)
+        view = await vis.display(charts=chart_params)
+        await self.save_view_message(self.current_dag_context, view)
+        return view
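Taken together, the retriever and executor are designed to bracket an LLM stage: the retriever turns a question into a schema-aware `HOContextBody`, the model produces the `{"thoughts", "sql", "display_type"}` dict, and the executor runs the SQL and renders the view. A rough sketch of the wiring is below; `build_text_to_sql_dag` and the stand-in `fake_llm` stage are illustrative only, since a real flow would use the `HOLLMOperator` defined in the next file plus the `SQLOutputParser` extended later in this patch:

```python
from dbgpt.agent.resource.database import DBResource
from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

from dbgpt.app.operators.datasource import (
    HODatasourceExecutorOperator,
    HODatasourceRetrieverOperator,
)


def build_text_to_sql_dag(datasource: DBResource) -> DAG:
    """Sketch of a text-to-SQL DAG; the LLM stage is a placeholder."""
    with DAG("example_text_to_sql") as dag:
        trigger = InputOperator(input_source=SimpleCallDataInputSource())
        retriever = HODatasourceRetrieverOperator(datasource=datasource)
        # Placeholder for HOLLMOperator >> SQLOutputParser: a real flow sends
        # the retriever context to the model and parses its JSON answer into
        # the dict the executor expects.
        fake_llm = MapOperator(
            lambda _ctx: {"thoughts": "demo", "sql": "SELECT 1", "display_type": "Table"}
        )
        executor = HODatasourceExecutorOperator(datasource=datasource)
        trigger >> retriever >> fake_llm >> executor
    return dag
```
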
diff --git a/dbgpt/app/operators/llm.py b/dbgpt/app/operators/llm.py
new file mode 100644
index 000000000..56b67a010
--- /dev/null
+++ b/dbgpt/app/operators/llm.py
@@ -0,0 +1,443 @@
+from typing import List, Literal, Optional, Tuple, Union
+
+from dbgpt._private.pydantic import BaseModel, Field
+from dbgpt.core import (
+    BaseMessage,
+    ChatPromptTemplate,
+    LLMClient,
+    ModelOutput,
+    ModelRequest,
+    StorageConversation,
+)
+from dbgpt.core.awel import (
+    DAG,
+    BaseOperator,
+    CommonLLMHttpRequestBody,
+    DAGContext,
+    DefaultInputContext,
+    InputOperator,
+    JoinOperator,
+    MapOperator,
+    SimpleCallDataInputSource,
+    TaskOutput,
+)
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OptionValue,
+    Parameter,
+    ViewMetadata,
+    ui,
+)
+from dbgpt.core.interface.operators.message_operator import (
+    BaseConversationOperator,
+    BufferedConversationMapperOperator,
+    TokenBufferedConversationMapperOperator,
+)
+from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
+from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
+from dbgpt.serve.conversation.serve import Serve as ConversationServe
+from dbgpt.util.i18n_utils import _
+from dbgpt.util.tracer import root_tracer
+
+
+class HOContextBody(BaseModel):
+    """Higher-order context body."""
+
+    context_key: str = Field(
+        "context",
+        description=_("The context key can be used as the key for formatting prompt."),
+    )
+    context: Union[str, List[str]] = Field(
+        ...,
+        description=_("The context."),
+    )
+
+
+class BaseHOLLMOperator(
+    BaseConversationOperator,
+    JoinOperator[ModelRequest],
+    LLMOperator,
+    StreamingLLMOperator,
+):
+    """Higher-order model request builder operator."""
+
+    def __init__(
+        self,
+        prompt_template: ChatPromptTemplate,
+        model: Optional[str] = None,
+        llm_client: Optional[LLMClient] = None,
+        history_merge_mode: Literal["none", "window", "token"] = "window",
+        user_message_key: str = "user_input",
+        history_key: Optional[str] = None,
+        keep_start_rounds: Optional[int] = None,
+        keep_end_rounds: Optional[int] = None,
+        max_token_limit: int = 2048,
+        **kwargs,
+    ):
+        JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
+        LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
+        StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)
+
+        # User must select a history merge mode
+        self._history_merge_mode = history_merge_mode
+        self._user_message_key = user_message_key
+        self._has_history = history_merge_mode != "none"
+        self._prompt_template = prompt_template
+        self._model = model
+        self._history_key = history_key
+        self._str_history = False
+        self._keep_start_rounds = keep_start_rounds if self._has_history else 0
+        self._keep_end_rounds = keep_end_rounds if self._has_history else 0
+        self._max_token_limit = max_token_limit
+        self._sub_compose_dag = self._build_conversation_composer_dag()
+
+    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
+        conv_serve = ConversationServe.get_instance(self.system_app)
+        self._storage = conv_serve.conv_storage
+        self._message_storage = conv_serve.message_storage
+
+        _: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
+        dag_ctx.current_task_context.set_task_input(
+            DefaultInputContext([dag_ctx.current_task_context])
+        )
+        if dag_ctx.streaming_call:
+            task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
+        else:
+            task_output = await LLMOperator._do_run(self, dag_ctx)
+
+        return task_output
+
+    async def after_dag_end(self, event_loop_task_id: int):
+        """Save the model output and view to the conversation storage."""
+        model_output: Optional[
+            ModelOutput
+        ] = await self.current_dag_context.get_from_share_data(
+            LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
+        )
+        model_output_view: Optional[
+            str
+        ] = await self.current_dag_context.get_from_share_data(
+            LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
+        )
+        storage_conv = await self.get_storage_conversation()
+        end_current_round: bool = False
+        if model_output and storage_conv:
+            # Save model output message to storage
+            storage_conv.add_ai_message(model_output.text)
+            end_current_round = True
+        if model_output_view and storage_conv:
+            # Save model output view to storage
+            storage_conv.add_view_message(model_output_view)
+            end_current_round = True
+        if end_current_round:
+            # End current conversation round and flush to storage
+            storage_conv.end_current_round()
+
+    async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
+        dynamic_inputs = []
+        for arg in args:
+            if isinstance(arg, HOContextBody):
+                dynamic_inputs.append(arg)
+        # Load and store chat history, default use InMemoryStorage.
+        storage_conv, history_messages = await self.blocking_func_to_async(
+            self._build_storage, req
+        )
+        # Save the storage conversation to share data, for the child operators
+        await self.current_dag_context.save_to_share_data(
+            self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
+        )
+
+        user_input = (
+            req.messages[-1] if isinstance(req.messages, list) else req.messages
+        )
+        prompt_dict = {
+            self._user_message_key: user_input,
+        }
+        for dynamic_input in dynamic_inputs:
+            if dynamic_input.context_key in prompt_dict:
+                raise ValueError(
+                    f"Duplicate context key '{dynamic_input.context_key}' in upstream "
+                    f"operators."
+ ) + prompt_dict[dynamic_input.context_key] = dynamic_input.context + + call_data = { + "messages": history_messages, + "prompt_dict": prompt_dict, + } + end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0] + # Sub dag, use the same dag context in the parent dag + messages = await end_node.call(call_data, dag_ctx=self.current_dag_context) + model_request = ModelRequest.build_request( + model=req.model, + messages=messages, + context=req.context, + temperature=req.temperature, + max_new_tokens=req.max_new_tokens, + span_id=root_tracer.get_current_span_id(), + echo=False, + ) + if storage_conv: + # Start new round + storage_conv.start_new_round() + storage_conv.add_user_message(user_input) + return model_request + + def _build_storage( + self, req: CommonLLMHttpRequestBody + ) -> Tuple[StorageConversation, List[BaseMessage]]: + # Create a new storage conversation, this will load the conversation from + # storage, so we must do this async + storage_conv: StorageConversation = StorageConversation( + conv_uid=req.conv_uid, + chat_mode=req.chat_mode, + user_name=req.user_name, + sys_code=req.sys_code, + conv_storage=self.storage, + message_storage=self.message_storage, + param_type="", + param_value=req.chat_param, + ) + # Get history messages from storage + history_messages: List[BaseMessage] = storage_conv.get_history_message( + include_system_message=False + ) + + return storage_conv, history_messages + + def _build_conversation_composer_dag(self) -> DAG: + with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag: + input_task = InputOperator(input_source=SimpleCallDataInputSource()) + # History transform task + if self._history_merge_mode == "token": + history_transform_task = TokenBufferedConversationMapperOperator( + model=self._model, + llm_client=self.llm_client, + max_token_limit=self._max_token_limit, + ) + else: + history_transform_task = BufferedConversationMapperOperator( + keep_start_rounds=self._keep_start_rounds, + keep_end_rounds=self._keep_end_rounds, + ) + if self._history_key: + history_key = self._history_key + else: + placeholders = self._prompt_template.get_placeholders() + if not placeholders or len(placeholders) != 1: + raise ValueError( + "The prompt template must have exactly one placeholder if " + "history_key is not provided." + ) + history_key = placeholders[0] + history_prompt_build_task = HistoryPromptBuilderOperator( + prompt=self._prompt_template, + history_key=history_key, + check_storage=False, + save_to_storage=False, + str_history=self._str_history, + ) + # Build composer dag + ( + input_task + >> MapOperator(lambda x: x["messages"]) + >> history_transform_task + >> history_prompt_build_task + ) + ( + input_task + >> MapOperator(lambda x: x["prompt_dict"]) + >> history_prompt_build_task + ) + + return composer_dag + + +_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from( + _("Prompt Template"), + "prompt_template", + ChatPromptTemplate, + description=_("The prompt template for the conversation."), +) +_PARAMETER_MODEL = Parameter.build_from( + _("Model Name"), + "model", + str, + optional=True, + default=None, + description=_("The model name."), +) + +_PARAMETER_LLM_CLIENT = Parameter.build_from( + _("LLM Client"), + "llm_client", + LLMClient, + optional=True, + default=None, + description=_( + "The LLM Client, how to connect to the LLM model, if not provided, it will use" + " the default client deployed by DB-GPT." 
+ ), +) +_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from( + _("History Message Merge Mode"), + "history_merge_mode", + str, + optional=True, + default="none", + options=[ + OptionValue(label="No History", name="none", value="none"), + OptionValue(label="Message Window", name="window", value="window"), + OptionValue(label="Token Length", name="token", value="token"), + ], + description=_( + "The history merge mode, supports 'none', 'window' and 'token'." + " 'none': no history merge, 'window': merge by conversation window, 'token': " + "merge by token length." + ), + ui=ui.UISelect(), +) +_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from( + _("User Message Key"), + "user_message_key", + str, + optional=True, + default="user_input", + description=_( + "The key of the user message in your prompt, default is 'user_input'." + ), +) +_PARAMETER_HISTORY_KEY = Parameter.build_from( + _("History Key"), + "history_key", + str, + optional=True, + default=None, + description=_( + "The chat history key, with chat history message pass to prompt template, " + "if not provided, it will parse the prompt template to get the key." + ), +) +_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from( + _("Keep Start Rounds"), + "keep_start_rounds", + int, + optional=True, + default=None, + description=_("The start rounds to keep in the chat history."), +) +_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from( + _("Keep End Rounds"), + "keep_end_rounds", + int, + optional=True, + default=None, + description=_("The end rounds to keep in the chat history."), +) +_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from( + _("Max Token Limit"), + "max_token_limit", + int, + optional=True, + default=2048, + description=_("The max token limit to keep in the chat history."), +) + +_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from( + _("Common LLM Request Body"), + "common_llm_request_body", + CommonLLMHttpRequestBody, + _("The common LLM request body."), +) +_INPUTS_EXTRA_CONTEXT = IOField.build_from( + _("Extra Context"), + "extra_context", + HOContextBody, + _( + "Extra context for building prompt(Knowledge context, database " + "schema, etc), you can add multiple context." + ), + dynamic=True, +) +_OUTPUTS_MODEL_OUTPUT = IOField.build_from( + _("Model Output"), + "model_output", + ModelOutput, + description=_("The model output."), +) +_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from( + _("Streaming Model Output"), + "streaming_model_output", + ModelOutput, + is_list=True, + description=_("The streaming model output."), +) + + +class HOLLMOperator(BaseHOLLMOperator): + metadata = ViewMetadata( + label=_("LLM Operator"), + name="higher_order_llm_operator", + category=OperatorCategory.LLM, + description=_( + "High-level LLM operator, supports multi-round conversation " + "(conversation window, token length and no multi-round)." 
+        ),
+        parameters=[
+            _PARAMETER_PROMPT_TEMPLATE.new(),
+            _PARAMETER_MODEL.new(),
+            _PARAMETER_LLM_CLIENT.new(),
+            _PARAMETER_HISTORY_MERGE_MODE.new(),
+            _PARAMETER_USER_MESSAGE_KEY.new(),
+            _PARAMETER_HISTORY_KEY.new(),
+            _PARAMETER_KEEP_START_ROUNDS.new(),
+            _PARAMETER_KEEP_END_ROUNDS.new(),
+            _PARAMETER_MAX_TOKEN_LIMIT.new(),
+        ],
+        inputs=[
+            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
+            _INPUTS_EXTRA_CONTEXT.new(),
+        ],
+        outputs=[
+            _OUTPUTS_MODEL_OUTPUT.new(),
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+class HOStreamingLLMOperator(BaseHOLLMOperator):
+    metadata = ViewMetadata(
+        label=_("Streaming LLM Operator"),
+        name="higher_order_streaming_llm_operator",
+        category=OperatorCategory.LLM,
+        description=_(
+            "High-level streaming LLM operator, supports multi-round conversation "
+            "(conversation window, token length and no multi-round)."
+        ),
+        parameters=[
+            _PARAMETER_PROMPT_TEMPLATE.new(),
+            _PARAMETER_MODEL.new(),
+            _PARAMETER_LLM_CLIENT.new(),
+            _PARAMETER_HISTORY_MERGE_MODE.new(),
+            _PARAMETER_USER_MESSAGE_KEY.new(),
+            _PARAMETER_HISTORY_KEY.new(),
+            _PARAMETER_KEEP_START_ROUNDS.new(),
+            _PARAMETER_KEEP_END_ROUNDS.new(),
+            _PARAMETER_MAX_TOKEN_LIMIT.new(),
+        ],
+        inputs=[
+            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
+            _INPUTS_EXTRA_CONTEXT.new(),
+        ],
+        outputs=[
+            _OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
diff --git a/dbgpt/app/operators/rag.py b/dbgpt/app/operators/rag.py
new file mode 100644
index 000000000..79d166ac0
--- /dev/null
+++ b/dbgpt/app/operators/rag.py
@@ -0,0 +1,191 @@
+from typing import List, Optional
+
+from dbgpt._private.config import Config
+from dbgpt.core.awel import MapOperator
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    FunctionDynamicOptions,
+    IOField,
+    OperatorCategory,
+    OptionValue,
+    Parameter,
+    ViewMetadata,
+    ui,
+)
+from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
+from dbgpt.util.i18n_utils import _
+
+from .llm import HOContextBody
+
+CFG = Config()
+
+
+def _load_space_name() -> List[OptionValue]:
+    from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
+
+    spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity())
+    return [
+        OptionValue(label=space.name, name=space.name, value=space.name)
+        for space in spaces
+    ]
+
+
+_PARAMETER_CONTEXT_KEY = Parameter.build_from(
+    _("Context Key"),
+    "context_key",
+    type=str,
+    optional=True,
+    default="context",
+    description=_("The key of the context, it will be used in building the prompt"),
+)
+_PARAMETER_TOP_K = Parameter.build_from(
+    _("Top K"),
+    "top_k",
+    type=int,
+    optional=True,
+    default=5,
+    description=_("The number of chunks to retrieve"),
+)
+_PARAMETER_SCORE_THRESHOLD = Parameter.build_from(
+    _("Minimum Match Score"),
+    "score_threshold",
+    type=float,
+    optional=True,
+    default=0.3,
+    description=_(
+        "The minimum match score for the retrieved chunks, it will be dropped if "
+        "the match score is less than the threshold"
+    ),
+    ui=ui.UISlider(attr=ui.UISlider.UIAttribute(min=0.0, max=1.0, step=0.1)),
+)
+
+_PARAMETER_RE_RANKER_ENABLED = Parameter.build_from(
+    _("Reranker Enabled"),
+    "reranker_enabled",
+    type=bool,
+    optional=True,
+    default=None,
+    description=_("Whether to enable the reranker"),
+)
+_PARAMETER_RE_RANKER_TOP_K = Parameter.build_from(
+    _("Reranker Top K"),
+    "reranker_top_k",
+    type=int,
+    optional=True,
+    default=3,
+    description=_("The top k for the reranker"),
+)
+
+_INPUTS_QUESTION = IOField.build_from(
+    _("User question"),
+    "query",
+    str,
+    description=_("The user question to retrieve the knowledge"),
+)
+_OUTPUTS_CONTEXT = IOField.build_from(
+    _("Retrieved context"),
+    "context",
+    HOContextBody,
+    description=_("The retrieved context from the knowledge space"),
+)
+
+
+class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
+    metadata = ViewMetadata(
+        label=_("Knowledge Operator"),
+        name="higher_order_knowledge_operator",
+        category=OperatorCategory.RAG,
+        description=_(
+            "Knowledge Operator, retrieve your knowledge (documents) from a"
+            " knowledge space"
+        ),
+        parameters=[
+            Parameter.build_from(
+                _("Knowledge Space Name"),
+                "knowledge_space",
+                type=str,
+                options=FunctionDynamicOptions(func=_load_space_name),
+                description=_("The name of the knowledge space"),
+            ),
+            _PARAMETER_CONTEXT_KEY.new(),
+            _PARAMETER_TOP_K.new(),
+            _PARAMETER_SCORE_THRESHOLD.new(),
+            _PARAMETER_RE_RANKER_ENABLED.new(),
+            _PARAMETER_RE_RANKER_TOP_K.new(),
+        ],
+        inputs=[
+            _INPUTS_QUESTION.new(),
+        ],
+        outputs=[
+            _OUTPUTS_CONTEXT.new(),
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(
+        self,
+        knowledge_space: str,
+        context_key: Optional[str] = "context",
+        top_k: Optional[int] = None,
+        score_threshold: Optional[float] = None,
+        reranker_enabled: Optional[bool] = None,
+        reranker_top_k: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._knowledge_space = knowledge_space
+        self._context_key = context_key
+        # Fall back to the declared parameter defaults so the comparisons
+        # below never see None.
+        self._top_k = top_k if top_k is not None else 5
+        self._score_threshold = score_threshold if score_threshold is not None else 0.3
+        self._reranker_enabled = reranker_enabled
+        self._reranker_top_k = reranker_top_k
+
+        from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
+        from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
+        from dbgpt.serve.rag.models.models import (
+            KnowledgeSpaceDao,
+            KnowledgeSpaceEntity,
+        )
+
+        spaces = KnowledgeSpaceDao().get_knowledge_space(
+            KnowledgeSpaceEntity(name=knowledge_space)
+        )
+        if len(spaces) != 1:
+            raise ValueError(f"Invalid knowledge space name: {knowledge_space}")
+        space = spaces[0]
+
+        reranker: Optional[RerankEmbeddingsRanker] = None
+
+        if CFG.RERANK_MODEL and self._reranker_enabled:
+            reranker_top_k = (
+                self._reranker_top_k
+                if self._reranker_top_k is not None
+                else CFG.RERANK_TOP_K
+            )
+            rerank_embeddings = RerankEmbeddingFactory.get_instance(
+                CFG.SYSTEM_APP
+            ).create()
+            reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=reranker_top_k)
+            if self._top_k < reranker_top_k or self._top_k < 20:
+                # The reranker is enabled, so the retriever should fetch at
+                # least 20 candidate chunks for it to re-rank.
+                self._top_k = max(reranker_top_k, 20)
+
+        self._space_retriever = KnowledgeSpaceRetriever(
+            space_id=space.id,
+            top_k=self._top_k,
+            rerank=reranker,
+        )
+
+    async def map(self, query: str) -> HOContextBody:
+        chunks = await self._space_retriever.aretrieve_with_scores(
+            query, self._score_threshold
+        )
+        return HOContextBody(
+            context_key=self._context_key,
+            context=[chunk.content for chunk in chunks],
+        )
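In a flow, `HOKnowledgeOperator` sits between the trigger and the LLM operator, feeding its `HOContextBody` into the LLM's dynamic "Extra Context" input. Calling it directly shows the shape of its output; this is only a sketch, and the space name `my_space` is hypothetical and must already exist, since the operator resolves the space at construction time:

```python
import asyncio

from dbgpt.app.operators.rag import HOKnowledgeOperator


async def main():
    # Hypothetical knowledge space name; the lookup happens in __init__.
    knowledge = HOKnowledgeOperator(knowledge_space="my_space", top_k=5)
    context = await knowledge.map("How is the reranker configured?")
    # context.context is the list of chunk contents; context.context_key
    # ("context" by default) is the prompt variable those chunks fill.
    print(context.context_key, len(context.context))


asyncio.run(main())
```
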
"FunctionDynamicOptions", diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 314cb2171..db8bbcb84 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -40,6 +40,9 @@ T = TypeVar("T", bound="ViewMixin") TM = TypeVar("TM", bound="TypeMetadata") +TAGS_ORDER_HIGH = "higher-order" +TAGS_ORDER_FIRST = "first-order" + def _get_type_name(type_: Type[Any]) -> str: """Get the type name of the type. @@ -143,6 +146,8 @@ def __init__(self, label: str, description: str): "agent": _CategoryDetail("Agent", "The agent operator"), "rag": _CategoryDetail("RAG", "The RAG operator"), "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"), + "database": _CategoryDetail("Database", "Interact with the database"), + "type_converter": _CategoryDetail("Type Converter", "Convert the type"), "example": _CategoryDetail("Example", "Example operator"), } @@ -159,6 +164,8 @@ class OperatorCategory(str, Enum): AGENT = "agent" RAG = "rag" EXPERIMENTAL = "experimental" + DATABASE = "database" + TYPE_CONVERTER = "type_converter" EXAMPLE = "example" def label(self) -> str: @@ -202,6 +209,7 @@ class OperatorType(str, Enum): "embeddings": _CategoryDetail("Embeddings", "The embeddings resource"), "rag": _CategoryDetail("RAG", "The resource"), "vector_store": _CategoryDetail("Vector Store", "The vector store resource"), + "database": _CategoryDetail("Database", "Interact with the database"), "example": _CategoryDetail("Example", "The example resource"), } @@ -219,6 +227,7 @@ class ResourceCategory(str, Enum): EMBEDDINGS = "embeddings" RAG = "rag" VECTOR_STORE = "vector_store" + DATABASE = "database" EXAMPLE = "example" def label(self) -> str: @@ -372,32 +381,41 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: "value": values.get("value"), "default": values.get("default"), } + is_list = values.get("is_list") or False if type_cls: for k, v in to_handle_values.items(): if v: - handled_v = cls._covert_to_real_type(type_cls, v) + handled_v = cls._covert_to_real_type(type_cls, v, is_list) values[k] = handled_v return values @classmethod - def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any: - if type_cls and v is not None: - typed_value: Any = v + def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any: + def _parse_single_value(vv: Any) -> Any: + typed_value: Any = vv try: # Try to convert the value to the type. 
if type_cls == "builtins.str": - typed_value = str(v) + typed_value = str(vv) elif type_cls == "builtins.int": - typed_value = int(v) + typed_value = int(vv) elif type_cls == "builtins.float": - typed_value = float(v) + typed_value = float(vv) elif type_cls == "builtins.bool": - if str(v).lower() in ["false", "0", "", "no", "off"]: + if str(vv).lower() in ["false", "0", "", "no", "off"]: return False - typed_value = bool(v) + typed_value = bool(vv) return typed_value except ValueError: - raise ValidationError(f"Value '{v}' is not valid for type {type_cls}") + raise ValidationError(f"Value '{vv}' is not valid for type {type_cls}") + + if type_cls and v is not None: + if not is_list: + _parse_single_value(v) + else: + if not isinstance(v, list): + raise ValidationError(f"Value '{v}' is not a list.") + return [_parse_single_value(vv) for vv in v] return v def get_typed_value(self) -> Any: @@ -413,11 +431,11 @@ def get_typed_value(self) -> Any: if is_variables and self.value is not None and isinstance(self.value, str): return VariablesPlaceHolder(self.name, self.value) else: - return self._covert_to_real_type(self.type_cls, self.value) + return self._covert_to_real_type(self.type_cls, self.value, self.is_list) def get_typed_default(self) -> Any: """Get the typed default.""" - return self._covert_to_real_type(self.type_cls, self.default) + return self._covert_to_real_type(self.type_cls, self.default, self.is_list) @classmethod def build_from( @@ -499,7 +517,10 @@ def to_dict(self) -> Dict: values = self.options.option_values() dict_value["options"] = [value.to_dict() for value in values] else: - dict_value["options"] = [value.to_dict() for value in self.options] + dict_value["options"] = [ + value.to_dict() if not isinstance(value, dict) else value + for value in self.options + ] if self.ui: dict_value["ui"] = self.ui.to_dict() @@ -594,6 +615,17 @@ def to_runnable_parameter( value = view_value return {self.name: value} + def new(self: TM) -> TM: + """Copy the metadata.""" + new_obj = self.__class__( + **self.model_dump(exclude_defaults=True, exclude={"ui", "options"}) + ) + if self.ui: + new_obj.ui = self.ui + if self.options: + new_obj.options = self.options + return new_obj + class BaseResource(Serializable, BaseModel): """The base resource.""" @@ -644,6 +676,17 @@ class IOField(Resource): description="Whether current field is list", examples=[True, False], ) + dynamic: bool = Field( + default=False, + description="Whether current field is dynamic", + examples=[True, False], + ) + dynamic_minimum: int = Field( + default=0, + description="The minimum count of the dynamic field, only valid when dynamic is" + " True", + examples=[0, 1, 2], + ) @classmethod def build_from( @@ -653,6 +696,8 @@ def build_from( type: Type, description: Optional[str] = None, is_list: bool = False, + dynamic: bool = False, + dynamic_minimum: int = 0, ): """Build the resource from the type.""" type_name = type.__qualname__ @@ -664,8 +709,22 @@ def build_from( type_cls=type_cls, is_list=is_list, description=description or label, + dynamic=dynamic, + dynamic_minimum=dynamic_minimum, ) + @model_validator(mode="before") + @classmethod + def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Pre fill the metadata.""" + if not isinstance(values, dict): + return values + if "dynamic" not in values: + values["dynamic"] = False + if "dynamic_minimum" not in values: + values["dynamic_minimum"] = 0 + return values + class BaseMetadata(BaseResource): """The base metadata.""" @@ -808,9 +867,40 @@ def 
get_origin_id(self) -> str: split_ids = self.id.split("_") return "_".join(split_ids[:-1]) + def _parse_ui_size(self) -> Optional[str]: + """Parse the ui size.""" + if not self.parameters: + return None + parameters_size = set() + for parameter in self.parameters: + if parameter.ui and parameter.ui.size: + parameters_size.add(parameter.ui.size) + for size in ["large", "middle", "small"]: + if size in parameters_size: + return size + return None + def to_dict(self) -> Dict: """Convert current metadata to json dict.""" + from .ui import _size_to_order + dict_value = model_to_dict(self, exclude={"parameters"}) + tags = dict_value.get("tags") + if not tags: + tags = {"ui_version": "flow2.0"} + elif isinstance(tags, dict) and "ui_version" not in tags: + tags["ui_version"] = "flow2.0" + + parsed_ui_size = self._parse_ui_size() + if parsed_ui_size: + exist_size = tags.get("ui_size") + if not exist_size or _size_to_order(parsed_ui_size) > _size_to_order( + exist_size + ): + # Use the higher order size as current size. + tags["ui_size"] = parsed_ui_size + + dict_value["tags"] = tags dict_value["parameters"] = [ parameter.to_dict() for parameter in self.parameters ] diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 87b828971..69e729ef7 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -97,6 +97,12 @@ def parse_data(cls, value: Any): return ResourceMetadata(**value) raise ValueError("Unable to infer the type for `data`") + def to_dict(self) -> Dict[str, Any]: + """Convert to dict.""" + dict_value = model_to_dict(self, exclude={"data"}) + dict_value["data"] = self.data.to_dict() + return dict_value + class FlowEdgeData(BaseModel): """Edge data in a flow.""" @@ -166,6 +172,12 @@ class FlowData(BaseModel): edges: List[FlowEdgeData] = Field(..., description="Edges in the flow") viewport: FlowPositionData = Field(..., description="Viewport of the flow") + def to_dict(self) -> Dict[str, Any]: + """Convert to dict.""" + dict_value = model_to_dict(self, exclude={"nodes"}) + dict_value["nodes"] = [n.to_dict() for n in self.nodes] + return dict_value + class _VariablesRequestBase(BaseModel): key: str = Field( @@ -518,9 +530,24 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["name"] = name return values + def model_dump(self, **kwargs): + """Override the model dump method.""" + exclude = kwargs.get("exclude", set()) + if "flow_dag" not in exclude: + exclude.add("flow_dag") + if "flow_data" not in exclude: + exclude.add("flow_data") + kwargs["exclude"] = exclude + common_dict = super().model_dump(**kwargs) + if self.flow_dag: + common_dict["flow_dag"] = None + if self.flow_data: + common_dict["flow_data"] = self.flow_data.to_dict() + return common_dict + def to_dict(self) -> Dict[str, Any]: """Convert to dict.""" - return model_to_dict(self, exclude={"flow_dag"}) + return model_to_dict(self, exclude={"flow_dag", "flow_data"}) def get_variables_dict(self) -> List[Dict[str, Any]]: """Get the variables dict.""" @@ -568,6 +595,11 @@ def build(self, flow_panel: FlowPanel) -> DAG: key_to_resource_nodes[key] = node key_to_resource[key] = node.data + if not key_to_operator_nodes and not key_to_resource_nodes: + raise FlowMetadataException( + "No operator or resource nodes found in the flow." 
+ ) + for edge in flow_data.edges: source_key = edge.source target_key = edge.target @@ -943,11 +975,17 @@ def fill_flow_panel(flow_panel: FlowPanel): new_param = input_parameters[i.name] i.label = new_param.label i.description = new_param.description + i.dynamic = new_param.dynamic + i.is_list = new_param.is_list + i.dynamic_minimum = new_param.dynamic_minimum for i in node.data.outputs: if i.name in output_parameters: new_param = output_parameters[i.name] i.label = new_param.label i.description = new_param.description + i.dynamic = new_param.dynamic + i.is_list = new_param.is_list + i.dynamic_minimum = new_param.dynamic_minimum else: data = cast(ResourceMetadata, node.data) key = data.get_origin_id() @@ -972,6 +1010,8 @@ def fill_flow_panel(flow_panel: FlowPanel): param.options = new_param.get_dict_options() # type: ignore param.default = new_param.default param.placeholder = new_param.placeholder + param.alias = new_param.alias + param.ui = new_param.ui except (FlowException, ValueError) as e: logger.warning(f"Unable to fill the flow panel: {e}") diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 928755a20..efe3d05e0 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Literal, Optional, Union -from dbgpt._private.pydantic import BaseModel, Field, model_to_dict +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict, model_validator from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowUIComponentException @@ -25,6 +25,16 @@ "code_editor", ] +_UI_SIZE_TYPE = Literal["large", "middle", "small"] +_SIZE_ORDER = {"large": 6, "middle": 4, "small": 2} + + +def _size_to_order(size: str) -> int: + """Convert size to order.""" + if size not in _SIZE_ORDER: + return -1 + return _SIZE_ORDER[size] + class RefreshableMixin(BaseModel): """Refreshable mixin.""" @@ -81,6 +91,10 @@ class UIAttribute(BaseModel): ) ui_type: _UI_TYPE = Field(..., description="UI component type") + size: Optional[_UI_SIZE_TYPE] = Field( + None, + description="The size of the component(small, middle, large)", + ) attr: Optional[UIAttribute] = Field( None, @@ -266,6 +280,27 @@ class AutoSize(BaseModel): description="The attributes of the component", ) + @model_validator(mode="after") + def check_size(self) -> "UITextArea": + """Check the size. + + Automatically set the size to large if the max_rows is greater than 10. 
+ """ + attr = self.attr + auto_size = attr.auto_size if attr else None + if not attr or not auto_size or isinstance(auto_size, bool): + return self + max_rows = ( + auto_size.max_rows + if isinstance(auto_size, self.UIAttribute.AutoSize) + else None + ) + size = self.size + if not size and max_rows and max_rows > 10: + # Automatically set the size to large if the max_rows is greater than 10 + self.size = "large" + return self + class UIAutoComplete(UIInput): """Auto complete component.""" @@ -450,7 +485,7 @@ class DefaultUITextArea(UITextArea): attr: Optional[UITextArea.UIAttribute] = Field( default_factory=lambda: UITextArea.UIAttribute( - auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40) + auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20) ), description="The attributes of the component", ) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 8f0298297..33692a423 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -29,6 +29,7 @@ from ..dag.base import DAG from ..flow import ( + TAGS_ORDER_HIGH, IOField, OperatorCategory, OperatorType, @@ -965,6 +966,7 @@ class CommonLLMHttpTrigger(HttpTrigger): _PARAMETER_MEDIA_TYPE.new(), _PARAMETER_STATUS_CODE.new(), ], + tags={"order": TAGS_ORDER_HIGH}, ) def __init__( @@ -1203,6 +1205,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]): "User input parsed operator, parse the user input from request body and " "return as a string" ), + tags={"order": TAGS_ORDER_HIGH}, ) def __init__(self, key: str = "user_input", **kwargs): diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index e6a5d24d4..94de92a03 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -195,6 +195,9 @@ class ModelRequest: temperature: Optional[float] = None """The temperature of the model inference.""" + top_p: Optional[float] = None + """The top p of the model inference.""" + max_new_tokens: Optional[int] = None """The maximum number of tokens to generate.""" diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 50a7b39e5..f67b83cb8 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -317,6 +317,25 @@ def messages_to_string( """ return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix) + @staticmethod + def parse_user_message(messages: List[ModelMessage]) -> str: + """Parse user message from messages. + + Args: + messages (List[ModelMessage]): The all messages in the conversation. 
+
+        Returns:
+            str: The user message
+        """
+        last_user_message = None
+        for message in messages[::-1]:
+            if message.role == ModelMessageRoleType.HUMAN:
+                last_user_message = message.content
+                break
+        if not last_user_message:
+            raise ValueError("No user message")
+        return last_user_message
+
 
 _SingleRoundMessage = List[BaseMessage]
 _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
                     content=ai_message.content,
                     index=ai_message.index,
                     round_index=ai_message.round_index,
-                    additional_kwargs=ai_message.additional_kwargs.copy()
-                    if ai_message.additional_kwargs
-                    else {},
+                    additional_kwargs=(
+                        ai_message.additional_kwargs.copy()
+                        if ai_message.additional_kwargs
+                        else {}
+                    ),
                 )
                 current_round.append(view_message)
         return sum(messages_by_round, [])
diff --git a/dbgpt/core/interface/operators/llm_operator.py b/dbgpt/core/interface/operators/llm_operator.py
index 45863d0a9..628c2f59f 100644
--- a/dbgpt/core/interface/operators/llm_operator.py
+++ b/dbgpt/core/interface/operators/llm_operator.py
@@ -246,10 +246,16 @@
 
     SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
     SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
+    SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"
 
-    def __init__(self, llm_client: Optional[LLMClient] = None):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+    ):
         """Create a new LLM operator."""
         self._llm_client = llm_client
+        self._save_model_output = save_model_output
 
     @property
     def llm_client(self) -> LLMClient:
@@ -262,9 +268,10 @@
         self, current_dag_context: DAGContext, model_output: ModelOutput
     ) -> None:
         """Save the model output to the share data."""
-        await current_dag_context.save_to_share_data(
-            self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
-        )
+        if self._save_model_output:
+            await current_dag_context.save_to_share_data(
+                self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
+            )
 
 
 class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
@@ -276,9 +283,14 @@
     This operator will generate a no streaming response.
     """
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a new LLM operator."""
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         MapOperator.__init__(self, **kwargs)
 
     async def map(self, request: ModelRequest) -> ModelOutput:
@@ -309,13 +321,18 @@
     This operator will generate streaming response.
     """
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a streaming operator for a LLM.
 
         Args:
             llm_client (LLMClient, optional): The LLM client. Defaults to None.
""" - super().__init__(llm_client=llm_client) + super().__init__(llm_client=llm_client, save_model_output=save_model_output) BaseOperator.__init__(self, **kwargs) async def streamify( # type: ignore diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py index 7d97230ac..241d8915f 100644 --- a/dbgpt/core/interface/operators/prompt_operator.py +++ b/dbgpt/core/interface/operators/prompt_operator.py @@ -4,14 +4,10 @@ from typing import Any, Dict, List, Optional, Union from dbgpt._private.pydantic import model_validator -from dbgpt.core import ( - ModelMessage, - ModelMessageRoleType, - ModelOutput, - StorageConversation, -) +from dbgpt.core import ModelMessage, ModelOutput, StorageConversation from dbgpt.core.awel import JoinOperator, MapOperator from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, IOField, OperatorCategory, OperatorType, @@ -42,6 +38,7 @@ name="common_chat_prompt_template", category=ResourceCategory.PROMPT, description=_("The operator to build the prompt with static prompt."), + tags={"order": TAGS_ORDER_HIGH}, parameters=[ Parameter.build_from( label=_("System Message"), @@ -101,9 +98,10 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: class BasePromptBuilderOperator(BaseConversationOperator, ABC): """The base prompt builder operator.""" - def __init__(self, check_storage: bool, **kwargs): + def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs): """Create a new prompt builder operator.""" super().__init__(check_storage=check_storage, **kwargs) + self._save_to_storage = save_to_storage async def format_prompt( self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any] @@ -122,8 +120,9 @@ async def format_prompt( pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables} messages = prompt.format_messages(**pass_kwargs) model_messages = ModelMessage.from_base_messages(messages) - # Start new round conversation, and save user message to storage - await self.start_new_round_conv(model_messages) + if self._save_to_storage: + # Start new round conversation, and save user message to storage + await self.start_new_round_conv(model_messages) return model_messages async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: @@ -132,13 +131,7 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: Args: messages (List[ModelMessage]): The messages. 
""" - lass_user_message = None - for message in messages[::-1]: - if message.role == ModelMessageRoleType.HUMAN: - lass_user_message = message.content - break - if not lass_user_message: - raise ValueError("No user message") + lass_user_message = ModelMessage.parse_user_message(messages) storage_conv: Optional[ StorageConversation ] = await self.get_storage_conversation() @@ -150,6 +143,8 @@ async def start_new_round_conv(self, messages: List[ModelMessage]) -> None: async def after_dag_end(self, event_loop_task_id: int): """Execute after the DAG finished.""" + if not self._save_to_storage: + return # Save the storage conversation to storage after the whole DAG finished storage_conv: Optional[ StorageConversation @@ -422,7 +417,7 @@ def __init__( self._prompt = prompt self._history_key = history_key self._str_history = str_history - BasePromptBuilderOperator.__init__(self, check_storage=check_storage) + BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) @rearrange_args_by_type @@ -455,7 +450,7 @@ def __init__( """Create a new history dynamic prompt builder operator.""" self._history_key = history_key self._str_history = str_history - BasePromptBuilderOperator.__init__(self, check_storage=check_storage) + BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs) @rearrange_args_by_type diff --git a/dbgpt/core/interface/output_parser.py b/dbgpt/core/interface/output_parser.py index faf29bfff..31e91b9f3 100644 --- a/dbgpt/core/interface/output_parser.py +++ b/dbgpt/core/interface/output_parser.py @@ -13,7 +13,13 @@ from dbgpt.core import ModelOutput from dbgpt.core.awel import MapOperator -from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + OperatorType, + ViewMetadata, +) from dbgpt.util.i18n_utils import _ T = TypeVar("T") @@ -271,7 +277,7 @@ async def map(self, input_value: ModelOutput) -> Any: if self.current_dag_context.streaming_call: return self.parse_model_stream_resp_ex(input_value, 0) else: - return self.parse_model_nostream_resp(input_value, "###") + return self.parse_model_nostream_resp(input_value, "#####################") def _parse_model_response(response: ResponseTye): @@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye): class SQLOutputParser(BaseOutputParser): """Parse the SQL output of an LLM call.""" + metadata = ViewMetadata( + label=_("SQL Output Parser"), + name="default_sql_output_parser", + category=OperatorCategory.OUTPUT_PARSER, + description=_("Parse the SQL output of an LLM call."), + parameters=[], + inputs=[ + IOField.build_from( + _("Model Output"), + "model_output", + ModelOutput, + description=_("The model output of upstream."), + ) + ], + outputs=[ + IOField.build_from( + _("Dict SQL Output"), + "dict", + dict, + description=_("The dict output after parsing."), + ) + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + def __init__(self, is_stream_out: bool = False, **kwargs): """Create a new SQL output parser.""" super().__init__(is_stream_out=is_stream_out, **kwargs) @@ -302,3 +333,57 @@ def parse_model_nostream_resp(self, response: ResponseTye, sep: str): model_out_text = super().parse_model_nostream_resp(response, sep) clean_str = super().parse_prompt_response(model_out_text) return json.loads(clean_str, strict=True) + + +class 
SQLListOutputParser(BaseOutputParser): + """Parse the SQL list output of an LLM call.""" + + metadata = ViewMetadata( + label=_("SQL List Output Parser"), + name="default_sql_list_output_parser", + category=OperatorCategory.OUTPUT_PARSER, + description=_( + "Parse the SQL list output of an LLM call, mostly used for dashboard." + ), + parameters=[], + inputs=[ + IOField.build_from( + _("Model Output"), + "model_output", + ModelOutput, + description=_("The model output of upstream."), + ) + ], + outputs=[ + IOField.build_from( + _("List SQL Output"), + "list", + dict, + is_list=True, + description=_("The list output after parsing."), + ) + ], + tags={"order": TAGS_ORDER_HIGH}, + ) + + def __init__(self, is_stream_out: bool = False, **kwargs): + """Create a new SQL list output parser.""" + super().__init__(is_stream_out=is_stream_out, **kwargs) + + def parse_model_nostream_resp(self, response: ResponseTye, sep: str): + """Parse the output of an LLM call.""" + from dbgpt.util.json_utils import find_json_objects + + model_out_text = super().parse_model_nostream_resp(response, sep) + json_objects = find_json_objects(model_out_text) + json_count = len(json_objects) + if json_count < 1: + raise ValueError("Unable to obtain valid output.") + + parsed_json_list = json_objects[0] + if not isinstance(parsed_json_list, list): + if isinstance(parsed_json_list, dict): + return [parsed_json_list] + else: + raise ValueError("Invalid output format.") + return parsed_json_list diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index 99c4b9b10..d1d025d0a 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -254,6 +254,18 @@ def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["input_variables"] = sorted(input_variables) return values + def get_placeholders(self) -> List[str]: + """Get all placeholders in the prompt template. + + Returns: + List[str]: The placeholders. + """ + placeholders = set() + for message in self.messages: + if isinstance(message, MessagesPlaceholder): + placeholders.add(message.variable_name) + return sorted(placeholders) + @dataclasses.dataclass class PromptTemplateIdentifier(ResourceIdentifier): diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py index 7e0aa0214..d58645cb7 100644 --- a/dbgpt/model/cluster/client.py +++ b/dbgpt/model/cluster/client.py @@ -42,13 +42,13 @@ class DefaultLLMClient(LLMClient): Args: worker_manager (WorkerManager): worker manager instance. - auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False. + auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to True. """ def __init__( self, worker_manager: Optional[WorkerManager] = None, - auto_convert_message: bool = False, + auto_convert_message: bool = True, ): self._worker_manager = worker_manager self._auto_covert_message = auto_convert_message diff --git a/dbgpt/model/operators/llm_operator.py b/dbgpt/model/operators/llm_operator.py index 56eee1e3e..02f14fe73 100644 --- a/dbgpt/model/operators/llm_operator.py +++ b/dbgpt/model/operators/llm_operator.py @@ -24,8 +24,13 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC): This class extends BaseOperator by adding LLM capabilities. 
""" - def __init__(self, default_client: Optional[LLMClient] = None, **kwargs): - super().__init__(default_client) + def __init__( + self, + default_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): + super().__init__(default_client, save_model_output=save_model_output) @property def llm_client(self) -> LLMClient: @@ -95,8 +100,13 @@ class LLMOperator(MixinLLMOperator, BaseLLMOperator): ], ) - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): + super().__init__(llm_client, save_model_output=save_model_output) BaseLLMOperator.__init__(self, llm_client, **kwargs) @@ -144,6 +154,11 @@ class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator): ], ) - def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs): - super().__init__(llm_client) + def __init__( + self, + llm_client: Optional[LLMClient] = None, + save_model_output: bool = True, + **kwargs, + ): + super().__init__(llm_client, save_model_output=save_model_output) BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs) diff --git a/dbgpt/model/utils/chatgpt_utils.py b/dbgpt/model/utils/chatgpt_utils.py index 057a04bf5..51c0fcae3 100644 --- a/dbgpt/model/utils/chatgpt_utils.py +++ b/dbgpt/model/utils/chatgpt_utils.py @@ -16,7 +16,13 @@ from dbgpt._private.pydantic import model_to_json from dbgpt.core.awel import TransformStreamAbsOperator -from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + IOField, + OperatorCategory, + OperatorType, + ViewMetadata, +) from dbgpt.core.interface.llm import ModelOutput from dbgpt.core.operators import BaseLLM from dbgpt.util.i18n_utils import _ @@ -184,6 +190,7 @@ class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str] ), ) ], + tags={"order": TAGS_ORDER_HIGH}, ) async def transform_stream(self, model_output: AsyncIterator[ModelOutput]): diff --git a/dbgpt/rag/summary/db_summary_client.py b/dbgpt/rag/summary/db_summary_client.py index de5ee83ff..8ce9a79e6 100644 --- a/dbgpt/rag/summary/db_summary_client.py +++ b/dbgpt/rag/summary/db_summary_client.py @@ -2,6 +2,7 @@ import logging import traceback +from typing import List from dbgpt._private.config import Config from dbgpt.component import SystemApp @@ -46,7 +47,7 @@ def db_summary_embedding(self, dbname, db_type): logger.info("db summary embedding success") - def get_db_summary(self, dbname, query, topk): + def get_db_summary(self, dbname, query, topk) -> List[str]: """Get user query related tables info.""" from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.storage.vector_store.base import VectorStoreConfig diff --git a/dbgpt/serve/agent/resource/datasource.py b/dbgpt/serve/agent/resource/datasource.py index 5e37cdd0c..0be2127dd 100644 --- a/dbgpt/serve/agent/resource/datasource.py +++ b/dbgpt/serve/agent/resource/datasource.py @@ -3,14 +3,41 @@ from typing import Any, List, Optional, Type, Union, cast from dbgpt._private.config import Config -from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource +from dbgpt.agent.resource.database import ( + _DEFAULT_PROMPT_TEMPLATE, + _DEFAULT_PROMPT_TEMPLATE_ZH, + DBParameters, + RDBMSConnectorResource, +) +from dbgpt.core.awel.flow import ( + TAGS_ORDER_HIGH, + FunctionDynamicOptions, + OptionValue, + Parameter, + 
ResourceCategory, + register_resource, +) from dbgpt.util import ParameterDescription +from dbgpt.util.i18n_utils import _ CFG = Config() logger = logging.getLogger(__name__) +def _load_datasource() -> List[OptionValue]: + dbs = CFG.local_db_manager.get_db_list() + results = [ + OptionValue( + label="[" + db["db_type"] + "]" + db["db_name"], + name=db["db_name"], + value=db["db_name"], + ) + for db in dbs + ] + return results + + @dataclasses.dataclass class DatasourceDBParameters(DBParameters): """The DB parameters for the datasource.""" @@ -57,6 +84,44 @@ def from_dict( return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields) +@register_resource( + _("Datasource Resource"), + "datasource", + category=ResourceCategory.DATABASE, + description=_( + "Connect to a datasource(retrieve table schemas and execute SQL to fetch data)." + ), + tags={"order": TAGS_ORDER_HIGH}, + parameters=[ + Parameter.build_from( + _("Datasource Name"), + "name", + str, + optional=True, + default="datasource", + description=_("The name of the datasource, default is 'datasource'."), + ), + Parameter.build_from( + _("DB Name"), + "db_name", + str, + description=_("The name of the database."), + options=FunctionDynamicOptions(func=_load_datasource), + ), + Parameter.build_from( + _("Prompt Template"), + "prompt_template", + str, + optional=True, + default=( + _DEFAULT_PROMPT_TEMPLATE_ZH + if CFG.LANGUAGE == "zh" + else _DEFAULT_PROMPT_TEMPLATE + ), + description=_("The prompt template to build a database prompt."), + ), + ], +) class DatasourceResource(RDBMSConnectorResource): def __init__(self, name: str, db_name: Optional[str] = None, **kwargs): conn = CFG.local_db_manager.get_connector(db_name) diff --git a/dbgpt/serve/agent/resource/knowledge.py b/dbgpt/serve/agent/resource/knowledge.py index 90359be4c..65c062415 100644 --- a/dbgpt/serve/agent/resource/knowledge.py +++ b/dbgpt/serve/agent/resource/knowledge.py @@ -64,6 +64,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource): """Knowledge Space retriever resource.""" def __init__(self, name: str, space_name: str, context: Optional[dict] = None): + # TODO: Build the retriever in a thread pool, it will block the event loop retriever = KnowledgeSpaceRetriever( space_id=space_name, top_k=context.get("top_k", None) if context else 4, diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 936e0ff0f..ff8bf1326 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -133,7 +133,10 @@ async def create( Returns: ServerResponse: The response """ - return Result.succ(service.create_and_save_dag(request)) + res = await blocking_func_to_async( + global_system_app, service.create_and_save_dag, request + ) + return Result.succ(res) @router.put( @@ -154,7 +157,10 @@ async def update( ServerResponse: The response """ try: - return Result.succ(service.update_flow(request)) + res = await blocking_func_to_async( + global_system_app, service.update_flow, request + ) + return Result.succ(res) except Exception as e: return Result.failed(msg=str(e)) @@ -176,9 +182,7 @@ async def delete( @router.get("/flows/{uid}") -async def get_flows( - uid: str, service: Service = Depends(get_service) -) -> Result[ServerResponse]: +async def get_flows(uid: str, service: Service = Depends(get_service)): """Get a Flow entity by uid Args: @@ -191,7 +195,7 @@ async def get_flows( flow = service.get({"uid": uid}) if not flow: raise HTTPException(status_code=404, detail=f"Flow {uid} not found") - return 
Result.succ(flow) + return Result.succ(flow.model_dump()) @router.get( @@ -467,7 +471,10 @@ async def import_flow( status_code=400, detail=f"invalid file extension {file_extension}" ) if save_flow: - return Result.succ(service.create_and_save_dag(flow)) + res = await blocking_func_to_async( + global_system_app, service.create_and_save_dag, flow + ) + return Result.succ(res) else: return Result.succ(flow) diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 15c9d5ceb..3aac0b24c 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -27,7 +27,7 @@ ChatCompletionStreamResponse, DeltaMessage, ) -from dbgpt.serve.core import BaseService +from dbgpt.serve.core import BaseService, blocking_func_to_async from dbgpt.storage.metadata import BaseDao from dbgpt.storage.metadata._base_dao import QUERY_SPEC from dbgpt.util.dbgpts.loader import DBGPTsLoader @@ -590,7 +590,11 @@ async def debug_flow( """ from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata - dag = self._flow_factory.build(request.flow) + dag = await blocking_func_to_async( + self._system_app, + self._flow_factory.build, + request.flow, + ) leaf_nodes = dag.leaf_nodes if len(leaf_nodes) != 1: raise ValueError("Chat Flow just support one leaf node in dag") diff --git a/dbgpt/serve/rag/operators/knowledge_space.py b/dbgpt/serve/rag/operators/knowledge_space.py index c37495ed5..3d2e1d846 100644 --- a/dbgpt/serve/rag/operators/knowledge_space.py +++ b/dbgpt/serve/rag/operators/knowledge_space.py @@ -223,7 +223,7 @@ def __init__( self._prompt = prompt self._history_key = history_key self._str_history = str_history - BasePromptBuilderOperator.__init__(self, check_storage=check_storage) + BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs) JoinOperator.__init__(self, combine_function=self.merge_context, **kwargs) @rearrange_args_by_type From bf63a967b51ad0ee880071335229d393eaeae809 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Wed, 28 Aug 2024 16:48:50 +0800 Subject: [PATCH 09/60] feat: Support endpoint placeholder --- dbgpt/core/awel/trigger/http_trigger.py | 51 ++++++++++++++-------- dbgpt/core/awel/trigger/trigger_manager.py | 9 ++-- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 33692a423..6e17be15e 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -58,6 +58,8 @@ logger = logging.getLogger(__name__) +ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}" + class AWELHttpError(RuntimeError): """AWEL Http Error.""" @@ -465,14 +467,11 @@ def mount_to_router( router (APIRouter): The router to mount the trigger. global_prefix (Optional[str], optional): The global prefix of the router. 
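Several endpoints in this series hand blocking service calls to `blocking_func_to_async` so the event loop stays responsive. Below is a standard-library sketch of that pattern; the real helper also takes the `SystemApp` so it can route work through a configured executor:

```python
import asyncio
from functools import partial
from typing import Any, Callable


async def run_blocking(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    # Run a synchronous callable in the default thread pool so the calling
    # coroutine (e.g. a FastAPI endpoint) does not stall the event loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(func, *args, **kwargs))


def create_and_save_dag(request: str) -> str:
    return f"built DAG for {request}"  # Stand-in for the real service method


print(asyncio.run(run_blocking(create_and_save_dag, "my-flow")))
```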
""" - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint dynamic_route_function = self._create_route_func() router.api_route( - self._endpoint, + endpoint, methods=self._methods, response_model=self._response_model, status_code=self._status_code, @@ -498,11 +497,9 @@ def mount_to_app( """ from dbgpt.util.fastapi import PriorityAPIRouter - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint dynamic_route_function = self._create_route_func() router = cast(PriorityAPIRouter, app.router) router.add_api_route( @@ -533,17 +530,28 @@ def remove_from_app( """ from fastapi import APIRouter - path = ( - join_paths(global_prefix, self._endpoint) - if global_prefix - else self._endpoint - ) + endpoint = self._resolved_endpoint() + + path = join_paths(global_prefix, endpoint) if global_prefix else endpoint app_router = cast(APIRouter, app.router) for i, r in enumerate(app_router.routes): if r.path_format == path: # type: ignore # TODO, remove with path and methods del app_router.routes[i] + def _resolved_endpoint(self) -> str: + """Get the resolved endpoint. + + Replace the placeholder {dag_id} with the real dag_id. + """ + endpoint = self._endpoint + if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint: + return endpoint + if not self.dag: + raise AWELHttpError("DAG is not set") + dag_id = self.dag.dag_id + return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id) + def _trigger_mode(self) -> str: if ( self._req_body @@ -959,7 +967,14 @@ class CommonLLMHttpTrigger(HttpTrigger): ), ], parameters=[ - _PARAMETER_ENDPOINT.new(), + Parameter.build_from( + _("API Endpoint"), + "endpoint", + str, + optional=True, + default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID, + description=_("The API endpoint"), + ), _PARAMETER_METHODS_POST_PUT.new(), _PARAMETER_STREAMING_RESPONSE.new(), _PARAMETER_RESPONSE_BODY.new(), @@ -971,7 +986,7 @@ class CommonLLMHttpTrigger(HttpTrigger): def __init__( self, - endpoint: str, + endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID, methods: Optional[Union[str, List[str]]] = "POST", streaming_response: bool = False, http_response_body: Optional[Type[BaseHttpBody]] = None, diff --git a/dbgpt/core/awel/trigger/trigger_manager.py b/dbgpt/core/awel/trigger/trigger_manager.py index 45b040147..94563226e 100644 --- a/dbgpt/core/awel/trigger/trigger_manager.py +++ b/dbgpt/core/awel/trigger/trigger_manager.py @@ -81,7 +81,8 @@ def register_trigger( raise ValueError(f"Current trigger {trigger} not an object of HttpTrigger") trigger_id = trigger.node_id if trigger_id not in self._trigger_map: - path = join_paths(self._router_prefix, trigger._endpoint) + real_endpoint = trigger._resolved_endpoint() + path = join_paths(self._router_prefix, real_endpoint) methods = trigger._methods # Check whether the route is already registered self._register_route_tables(path, methods) @@ -116,9 +117,9 @@ def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None: if not app: raise ValueError("System app not initialized") trigger.remove_from_app(app, self._router_prefix) - self._unregister_route_tables( - join_paths(self._router_prefix, trigger._endpoint), trigger._methods - ) + real_endpoint = trigger._resolved_endpoint() + path = join_paths(self._router_prefix, 
real_endpoint) + self._unregister_route_tables(path, trigger._methods) del self._trigger_map[trigger_id] def _init_app(self, system_app: SystemApp): From 0e71991f7ef2fce5c08847b9e98c1087e42d7b79 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 30 Aug 2024 15:02:53 +0800 Subject: [PATCH 10/60] chore: Merge latest code --- dbgpt/app/operators/datasource.py | 29 +++++++- dbgpt/app/operators/rag.py | 19 +++++ dbgpt/core/awel/dag/base.py | 15 ++-- dbgpt/core/awel/flow/base.py | 11 +++ dbgpt/core/awel/flow/flow_factory.py | 89 ++++++++++++++++++++---- dbgpt/core/awel/trigger/http_trigger.py | 20 ++++++ dbgpt/serve/flow/service/service.py | 4 +- examples/awel/awel_flow_ui_components.py | 4 +- 8 files changed, 169 insertions(+), 22 deletions(-) diff --git a/dbgpt/app/operators/datasource.py b/dbgpt/app/operators/datasource.py index 7fe16feaa..320df8d22 100644 --- a/dbgpt/app/operators/datasource.py +++ b/dbgpt/app/operators/datasource.py @@ -4,6 +4,7 @@ from dbgpt._private.config import Config from dbgpt.agent.resource.database import DBResource +from dbgpt.core import Chunk from dbgpt.core.awel import DAGContext, MapOperator from dbgpt.core.awel.flow import ( TAGS_ORDER_HIGH, @@ -193,6 +194,19 @@ async def save_view_message(self, dag_ctx: DAGContext, view: str): class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]): """Retrieve the table schemas from the datasource.""" + _share_data_key = "__datasource_retriever_chunks__" + + class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]): + async def map(self, context: HOContextBody) -> List[Chunk]: + schema_info = await self.current_dag_context.get_from_share_data( + HODatasourceRetrieverOperator._share_data_key + ) + if isinstance(schema_info, list): + chunks = [Chunk(content=table_info) for table_info in schema_info] + else: + chunks = [Chunk(content=schema_info)] + return chunks + metadata = ViewMetadata( label=_("Datasource Retriever Operator"), name="higher_order_datasource_retriever_operator", @@ -207,7 +221,17 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]): _PARAMETER_CONTEXT_KEY.new(), ], inputs=[_INPUTS_QUESTION.new()], - outputs=[_OUTPUTS_CONTEXT.new()], + outputs=[ + _OUTPUTS_CONTEXT.new(), + IOField.build_from( + _("Retrieved schema chunks"), + "chunks", + Chunk, + is_list=True, + description=_("The retrieved schema chunks from the datasource"), + mappers=[ChunkMapper], + ), + ], tags={"order": TAGS_ORDER_HIGH}, ) @@ -239,6 +263,9 @@ async def map(self, question: str) -> HOContextBody: db=db_name, question=question, ) + await self.current_dag_context.save_to_share_data( + self._share_data_key, schema_info + ) context = self._prompt_template.format( db_name=db_name, table_info=schema_info, diff --git a/dbgpt/app/operators/rag.py b/dbgpt/app/operators/rag.py index 79d166ac0..d7fa75b24 100644 --- a/dbgpt/app/operators/rag.py +++ b/dbgpt/app/operators/rag.py @@ -1,6 +1,7 @@ from typing import List, Optional from dbgpt._private.config import Config +from dbgpt.core import Chunk from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import ( TAGS_ORDER_HIGH, @@ -93,6 +94,15 @@ def _load_space_name() -> List[OptionValue]: class HOKnowledgeOperator(MapOperator[str, HOContextBody]): + _share_data_key = "_higher_order_knowledge_operator_retriever_chunks" + + class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]): + async def map(self, context: HOContextBody) -> List[Chunk]: + chunks = await self.current_dag_context.get_from_share_data( + HOKnowledgeOperator._share_data_key + ) + return 
chunks + metadata = ViewMetadata( label=_("Knowledge Operator"), name="higher_order_knowledge_operator", @@ -122,6 +132,14 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]): ], outputs=[ _OUTPUTS_CONTEXT.new(), + IOField.build_from( + _("Chunks"), + "chunks", + Chunk, + is_list=True, + description=_("The retrieved chunks from the knowledge space"), + mappers=[ChunkMapper], + ), ], tags={"order": TAGS_ORDER_HIGH}, ) @@ -185,6 +203,7 @@ async def map(self, query: str) -> HOContextBody: chunks = await self._space_retriever.aretrieve_with_scores( query, self._score_threshold ) + await self.current_dag_context.save_to_share_data(self._share_data_key, chunks) return HOContextBody( context_key=self._context_key, context=[chunk.content for chunk in chunks], diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index ffe6a7b0e..2f3521d24 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -619,6 +619,7 @@ def __init__( self._node_name_to_ids: Dict[str, str] = node_name_to_ids self._event_loop_task_id = event_loop_task_id self._dag_variables = dag_variables + self._share_data_lock = asyncio.Lock() @property def _task_outputs(self) -> Dict[str, TaskContext]: @@ -680,8 +681,9 @@ async def get_from_share_data(self, key: str) -> Any: Returns: Any: The share data, you can cast it to the real type """ - logger.debug(f"Get share data by key {key} from {id(self._share_data)}") - return self._share_data.get(key) + async with self._share_data_lock: + logger.debug(f"Get share data by key {key} from {id(self._share_data)}") + return self._share_data.get(key) async def save_to_share_data( self, key: str, data: Any, overwrite: bool = False @@ -694,10 +696,11 @@ async def save_to_share_data( overwrite (bool): Whether overwrite the share data if the key already exists. Defaults to None. """ - if key in self._share_data and not overwrite: - raise ValueError(f"Share data key {key} already exists") - logger.debug(f"Save share data by key {key} to {id(self._share_data)}") - self._share_data[key] = data + async with self._share_data_lock: + if key in self._share_data and not overwrite: + raise ValueError(f"Share data key {key} already exists") + logger.debug(f"Save share data by key {key} to {id(self._share_data)}") + self._share_data[key] = data async def get_task_share_data(self, task_name: str, key: str) -> Any: """Get share data by task name and key. diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index db8bbcb84..4e691ed08 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -687,6 +687,10 @@ class IOField(Resource): " True", examples=[0, 1, 2], ) + mappers: Optional[List[str]] = Field( + default=None, + description="The mappers of the field, transform the field to the target type", + ) @classmethod def build_from( @@ -698,10 +702,16 @@ def build_from( is_list: bool = False, dynamic: bool = False, dynamic_minimum: int = 0, + mappers: Optional[Union[Type, List[Type]]] = None, ): """Build the resource from the type.""" type_name = type.__qualname__ type_cls = _get_type_name(type) + # TODO: Check the mapper instance can be created without required + # parameters. 
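The `asyncio.Lock` added to the DAG context guards exactly this kind of traffic: one operator saving retrieved chunks to share data while a downstream mapper reads them back. A self-contained sketch of the locked store (not the real `DAGContext`):

```python
import asyncio
from typing import Any, Dict


class ShareDataStore:
    def __init__(self) -> None:
        self._share_data: Dict[str, Any] = {}
        self._lock = asyncio.Lock()  # One lock serializes readers and writers.

    async def save(self, key: str, data: Any, overwrite: bool = False) -> None:
        async with self._lock:
            if key in self._share_data and not overwrite:
                raise ValueError(f"Share data key {key} already exists")
            self._share_data[key] = data

    async def get(self, key: str) -> Any:
        async with self._lock:
            return self._share_data.get(key)


async def demo() -> None:
    store = ShareDataStore()
    await store.save("__datasource_retriever_chunks__", ["CREATE TABLE user(...)"])
    print(await store.get("__datasource_retriever_chunks__"))


asyncio.run(demo())
```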
+ if mappers and not isinstance(mappers, list): + mappers = [mappers] + mappers_cls = [_get_type_name(m) for m in mappers] if mappers else None return cls( label=label, name=name, @@ -711,6 +721,7 @@ def build_from( description=description or label, dynamic=dynamic, dynamic_minimum=dynamic_minimum, + mappers=mappers_cls, ) @model_validator(mode="before") diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 69e729ef7..9d157a998 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -1,10 +1,11 @@ """Build AWEL DAGs from serialized data.""" +import dataclasses import logging import uuid from contextlib import suppress from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast +from typing import Any, Dict, List, Literal, Optional, Type, Union, cast from typing_extensions import Annotated @@ -565,6 +566,17 @@ def parse_variables( return [FlowVariables(**v) for v in variables] +@dataclasses.dataclass +class _KeyToNodeItem: + """Key to node item.""" + + key: str + source_order: int + target_order: int + mappers: List[str] + edge_index: int + + class FlowFactory: """Flow factory.""" @@ -580,8 +592,10 @@ def build(self, flow_panel: FlowPanel) -> DAG: key_to_operator_nodes: Dict[str, FlowNodeData] = {} key_to_resource_nodes: Dict[str, FlowNodeData] = {} key_to_resource: Dict[str, ResourceMetadata] = {} - key_to_downstream: Dict[str, List[Tuple[str, int, int]]] = {} - key_to_upstream: Dict[str, List[Tuple[str, int, int]]] = {} + # Record current node's downstream + key_to_downstream: Dict[str, List[_KeyToNodeItem]] = {} + # Record current node's upstream + key_to_upstream: Dict[str, List[_KeyToNodeItem]] = {} key_to_upstream_node: Dict[str, List[FlowNodeData]] = {} for node in flow_data.nodes: key = node.id @@ -600,7 +614,7 @@ def build(self, flow_panel: FlowPanel) -> DAG: "No operator or resource nodes found in the flow." ) - for edge in flow_data.edges: + for edge_index, edge in enumerate(flow_data.edges): source_key = edge.source target_key = edge.target source_node: FlowNodeData | None = key_to_operator_nodes.get( @@ -620,12 +634,37 @@ def build(self, flow_panel: FlowPanel) -> DAG: if source_node.data.is_operator and target_node.data.is_operator: # Operator to operator. + mappers = [] + for i, out in enumerate(source_node.data.outputs): + if i != edge.source_order: + continue + if out.mappers: + # Current edge is a mapper edge, find the mappers. + mappers = out.mappers + # Note: Not support mappers in the inputs of the target node now. + downstream = key_to_downstream.get(source_key, []) - downstream.append((target_key, edge.source_order, edge.target_order)) + downstream.append( + _KeyToNodeItem( + key=target_key, + source_order=edge.source_order, + target_order=edge.target_order, + mappers=mappers, + edge_index=edge_index, + ) + ) key_to_downstream[source_key] = downstream upstream = key_to_upstream.get(target_key, []) - upstream.append((source_key, edge.source_order, edge.target_order)) + upstream.append( + _KeyToNodeItem( + key=source_key, + source_order=edge.source_order, + target_order=edge.target_order, + mappers=mappers, + edge_index=edge_index, + ) + ) key_to_upstream[target_key] = upstream elif not source_node.data.is_operator and target_node.data.is_operator: # Resource to operator. @@ -683,10 +722,10 @@ def build(self, flow_panel: FlowPanel) -> DAG: # Sort the keys by the order of the nodes. for key, value in key_to_downstream.items(): # Sort by source_order. 
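The `(key, source_order, target_order)` tuples become `_KeyToNodeItem` records so mapper names can travel with each edge, and the sorts below switch from tuple indices to named fields. A small illustration with made-up node keys:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class EdgeItem:
    # Same shape as the _KeyToNodeItem dataclass introduced in this patch.
    key: str
    source_order: int
    target_order: int
    mappers: List[str]
    edge_index: int


edges = [
    EdgeItem("llm_node", 1, 0, [], 0),
    EdgeItem("chunk_node", 0, 1, ["ChunkMapper"], 1),
]
# Downstream lists follow the source's output order, upstream lists follow the
# target's input order, replacing the old tuple sorts on x[1] and x[2].
downstream = sorted(edges, key=lambda x: x.source_order)
upstream = sorted(edges, key=lambda x: x.target_order)
assert downstream[0].key == "chunk_node"
assert upstream[0].key == "llm_node"
```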
- key_to_downstream[key] = sorted(value, key=lambda x: x[1]) + key_to_downstream[key] = sorted(value, key=lambda x: x.source_order) for key, value in key_to_upstream.items(): # Sort by target_order. - key_to_upstream[key] = sorted(value, key=lambda x: x[2]) + key_to_upstream[key] = sorted(value, key=lambda x: x.target_order) sorted_key_to_resource_nodes = list(key_to_resource_nodes.values()) sorted_key_to_resource_nodes = sorted( @@ -784,8 +823,8 @@ def build_dag( self, flow_panel: FlowPanel, key_to_tasks: Dict[str, DAGNode], - key_to_downstream: Dict[str, List[Tuple[str, int, int]]], - key_to_upstream: Dict[str, List[Tuple[str, int, int]]], + key_to_downstream: Dict[str, List[_KeyToNodeItem]], + key_to_upstream: Dict[str, List[_KeyToNodeItem]], dag_id: Optional[str] = None, ) -> DAG: """Build the DAG.""" @@ -832,7 +871,8 @@ def build_dag( # This upstream has been sorted according to the order in the downstream # So we just need to connect the task to the upstream. - for upstream_key, _, _ in upstream: + for up_item in upstream: + upstream_key = up_item.key # Just one direction. upstream_task = key_to_tasks.get(upstream_key) if not upstream_task: @@ -843,7 +883,13 @@ def build_dag( upstream_task.set_node_id(dag._new_node_id()) if upstream_task is None: raise ValueError("Unable to find upstream task.") - upstream_task >> task + tasks = _build_mapper_operators(dag, up_item.mappers) + tasks.append(task) + last_task = upstream_task + for t in tasks: + # Connect the task to the upstream task. + last_task >> t + last_task = t return dag def pre_load_requirements(self, flow_panel: FlowPanel): @@ -950,6 +996,23 @@ def _topological_sort( return key_to_order +def _build_mapper_operators(dag: DAG, mappers: List[str]) -> List[DAGNode]: + from .base import _get_type_cls + + tasks = [] + for mapper in mappers: + try: + mapper_cls = _get_type_cls(mapper) + task = mapper_cls() + if not task._node_id: + task.set_node_id(dag._new_node_id()) + tasks.append(task) + except Exception as e: + err_msg = f"Unable to build mapper task: {mapper}, error: {e}" + raise FlowMetadataException(err_msg) + return tasks + + def fill_flow_panel(flow_panel: FlowPanel): """Fill the flow panel with the latest metadata. 
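`_build_mapper_operators` together with the `last_task >> t` loop splices freshly built mapper tasks between an upstream output and its downstream consumer. A toy model of that wiring, with `__rshift__` standing in for AWEL's `>>` connection:

```python
from typing import List


class Node:
    """Toy DAG node; `>>` appends a downstream edge like AWEL operators do."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.downstream: List["Node"] = []

    def __rshift__(self, other: "Node") -> "Node":
        self.downstream.append(other)
        return other


upstream_task, mapper_task, target_task = Node("retriever"), Node("mapper"), Node("llm")

# Mirror build_dag: chain upstream -> every mapper -> the real downstream task.
tasks = [mapper_task, target_task]
last_task = upstream_task
for t in tasks:
    last_task >> t
    last_task = t

assert upstream_task.downstream == [mapper_task]
assert mapper_task.downstream == [target_task]
```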
@@ -978,6 +1041,7 @@ def fill_flow_panel(flow_panel: FlowPanel): i.dynamic = new_param.dynamic i.is_list = new_param.is_list i.dynamic_minimum = new_param.dynamic_minimum + i.mappers = new_param.mappers for i in node.data.outputs: if i.name in output_parameters: new_param = output_parameters[i.name] @@ -986,6 +1050,7 @@ def fill_flow_panel(flow_panel: FlowPanel): i.dynamic = new_param.dynamic i.is_list = new_param.is_list i.dynamic_minimum = new_param.dynamic_minimum + i.mappers = new_param.mappers else: data = cast(ResourceMetadata, node.data) key = data.get_origin_id() diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 6e17be15e..fd503f566 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -945,6 +945,16 @@ def __init__( class CommonLLMHttpTrigger(HttpTrigger): """Common LLM http trigger for AWEL.""" + class MessagesOutputMapper(MapOperator[CommonLLMHttpRequestBody, str]): + """Messages output mapper.""" + + async def map(self, request_body: CommonLLMHttpRequestBody) -> str: + """Map the request body to messages.""" + if isinstance(request_body.messages, str): + return request_body.messages + else: + raise ValueError("Messages to be transformed is not a string") + metadata = ViewMetadata( label=_("Common LLM Http Trigger"), name="common_llm_http_trigger", @@ -965,6 +975,16 @@ class CommonLLMHttpTrigger(HttpTrigger): "LLM http body" ), ), + IOField.build_from( + _("Request String Messages"), + "request_string_messages", + str, + description=_( + "The request string messages of the API endpoint, parsed from " + "'messages' field of the request body" + ), + mappers=[MessagesOutputMapper], + ), ], parameters=[ Parameter.build_from( diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 3aac0b24c..cc9c00341 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -388,7 +388,9 @@ def get_list_by_page( Returns: List[ServerResponse]: The response """ - page_result = self.dao.get_list_page(request, page, page_size) + page_result = self.dao.get_list_page( + request, page, page_size, desc_order_column=ServeEntity.gmt_modified.name + ) for item in page_result.items: metadata = self.dag_manager.get_dag_metadata( item.dag_id, alias_name=item.uid diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index db92ca09d..2db9607ea 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -852,7 +852,7 @@ class ExampleFlowUploadOperator(MapOperator[str, str]): ui=ui.UIUpload( max_file_size=1024 * 1024 * 100, up_event="button_click", - file_types=["image/*", "*.pdf"], + file_types=["image/*", ".pdf"], drag=True, attr=ui.UIUpload.UIAttribute(max_count=5), ), @@ -897,7 +897,7 @@ async def map(self, user_name: str) -> str: files_metadata = await self.blocking_func_to_async( self._parse_files_metadata, fsc ) - files_metadata_str = json.dumps(files_metadata, ensure_ascii=False) + files_metadata_str = json.dumps(files_metadata, ensure_ascii=False, indent=4) return "Your name is %s, and you files are %s." 
% ( user_name, files_metadata_str, From 93527e0b0493a708a361236887176e9e5e70fd30 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 29 Aug 2024 19:37:45 +0800 Subject: [PATCH 11/60] feat: Support variables query API --- dbgpt/core/interface/variables.py | 1 + dbgpt/model/cluster/client.py | 6 +- dbgpt/serve/flow/api/endpoints.py | 63 +++++ dbgpt/serve/flow/api/schemas.py | 13 +- dbgpt/serve/flow/api/variables_provider.py | 134 +++++++++++ dbgpt/serve/flow/service/service.py | 2 +- dbgpt/serve/flow/service/variables_service.py | 224 +++++++++++++++++- dbgpt/util/pagination_utils.py | 26 ++ dbgpt/util/tests/test_pagination_utils.py | 84 +++++++ 9 files changed, 544 insertions(+), 9 deletions(-) create mode 100644 dbgpt/util/tests/test_pagination_utils.py diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py index 22e035d52..5d538ad25 100644 --- a/dbgpt/core/interface/variables.py +++ b/dbgpt/core/interface/variables.py @@ -31,6 +31,7 @@ BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets" BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms" BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings" +# Not implemented yet BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers" BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources" BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents" diff --git a/dbgpt/model/cluster/client.py b/dbgpt/model/cluster/client.py index d58645cb7..f4141cc23 100644 --- a/dbgpt/model/cluster/client.py +++ b/dbgpt/model/cluster/client.py @@ -27,7 +27,7 @@ name="auto_convert_message", type=bool, optional=True, - default=False, + default=True, description=_( "Whether to auto convert the messages that are not supported " "by the LLM to a compatible format" @@ -128,7 +128,7 @@ async def count_token(self, model: str, prompt: str) -> int: name="auto_convert_message", type=bool, optional=True, - default=False, + default=True, description=_( "Whether to auto convert the messages that are not supported " "by the LLM to a compatible format" @@ -158,7 +158,7 @@ class RemoteLLMClient(DefaultLLMClient): def __init__( self, controller_address: str = "http://127.0.0.1:8000", - auto_convert_message: bool = False, + auto_convert_message: bool = True, ): """Initialize the RemoteLLMClient.""" from dbgpt.model.cluster import ModelRegistryClient, RemoteWorkerManager diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index ff8bf1326..a985b9ab1 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -21,6 +21,7 @@ RefreshNodeRequest, ServeRequest, ServerResponse, + VariablesKeyResponse, VariablesRequest, VariablesResponse, ) @@ -364,6 +365,62 @@ async def update_variables( return Result.succ(res) +@router.get( + "/variables", + response_model=Result[PaginationResult[VariablesResponse]], + dependencies=[Depends(check_api_key)], +) +async def get_variables_by_keys( + key: str = Query(..., description="variable key"), + scope: Optional[str] = Query(default=None, description="scope"), + scope_key: Optional[str] = Query(default=None, description="scope key"), + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + page: int = Query(default=1, description="current page"), + page_size: int = Query(default=20, description="page size"), +) -> Result[PaginationResult[VariablesResponse]]: + """Get the variables by keys + + Returns: + VariablesResponse: The response + """ + res = 
await get_variable_service().get_list_by_page( + key, + scope, + scope_key, + user_name, + sys_code, + page, + page_size, + ) + return Result.succ(res) + + +@router.get( + "/variables/keys", + response_model=Result[List[VariablesKeyResponse]], + dependencies=[Depends(check_api_key)], +) +async def get_variables_keys( + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + category: Optional[str] = Query(default=None, description="category"), +) -> Result[List[VariablesKeyResponse]]: + """Get the variable keys + + Returns: + VariablesKeyResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, + get_variable_service().list_keys, + user_name, + sys_code, + category, + ) + return Result.succ(res) + + @router.post("/flow/debug", dependencies=[Depends(check_api_key)]) async def debug_flow( flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service) @@ -482,10 +539,13 @@ async def import_flow( def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" from .variables_provider import ( + BuiltinAgentsVariablesProvider, BuiltinAllSecretVariablesProvider, BuiltinAllVariablesProvider, + BuiltinDatasourceVariablesProvider, BuiltinEmbeddingsVariablesProvider, BuiltinFlowVariablesProvider, + BuiltinKnowledgeSpacesVariablesProvider, BuiltinLLMVariablesProvider, BuiltinNodeVariablesProvider, ) @@ -499,4 +559,7 @@ def init_endpoints(system_app: SystemApp) -> None: system_app.register(BuiltinAllSecretVariablesProvider) system_app.register(BuiltinLLMVariablesProvider) system_app.register(BuiltinEmbeddingsVariablesProvider) + system_app.register(BuiltinDatasourceVariablesProvider) + system_app.register(BuiltinAgentsVariablesProvider) + system_app.register(BuiltinKnowledgeSpacesVariablesProvider) global_system_app = system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index cf82de982..6053dd885 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -2,7 +2,11 @@ from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core.awel import CommonLLMHttpRequestBody -from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest +from dbgpt.core.awel.flow.flow_factory import ( + FlowPanel, + VariablesRequest, + _VariablesRequestBase, +) from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest from ..config import SERVE_APP_NAME_HUMP @@ -28,6 +32,13 @@ class VariablesResponse(VariablesRequest): ) +class VariablesKeyResponse(_VariablesRequestBase): + """Variables Key response model. + + Just include the key, for select options in the frontend. 
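A hypothetical client call against the new variables query endpoint; the `/api/v2/serve/awel` prefix, the Bearer header, and the `data.items` envelope are assumptions that depend on how the serve app is mounted and configured:

```python
import requests  # Any HTTP client works; requests keeps the example short.

resp = requests.get(
    "http://localhost:5670/api/v2/serve/awel/variables",
    params={"key": "dbgpt.core.model.llms", "page": 1, "page_size": 20},
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    timeout=10,
)
resp.raise_for_status()
payload = resp.json()
for item in (payload.get("data") or {}).get("items", []):
    print(item.get("name"), item.get("label"))
```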
+ """ + + class RefreshNodeRequest(BaseModel): """Flow response model""" diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py index 4728f80e6..27ed63bf5 100644 --- a/dbgpt/serve/flow/api/variables_provider.py +++ b/dbgpt/serve/flow/api/variables_provider.py @@ -1,9 +1,12 @@ from typing import List, Literal, Optional from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_AGENTS, + BUILTIN_VARIABLES_CORE_DATASOURCES, BUILTIN_VARIABLES_CORE_EMBEDDINGS, BUILTIN_VARIABLES_CORE_FLOW_NODES, BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES, BUILTIN_VARIABLES_CORE_LLMS, BUILTIN_VARIABLES_CORE_SECRETS, BUILTIN_VARIABLES_CORE_VARIABLES, @@ -54,6 +57,7 @@ def get_variables( scope_key=scope_key, sys_code=sys_code, user_name=user_name, + description=flow.description, ) ) return variables @@ -91,6 +95,7 @@ def get_variables( scope_key=scope_key, sys_code=sys_code, user_name=user_name, + description=metadata.get("description"), ) ) return variables @@ -122,10 +127,14 @@ def _get_variables_from_db( name=var.name, label=var.label, value=var.value, + category=var.category, + value_type=var.value_type, scope=scope, scope_key=scope_key, sys_code=sys_code, user_name=user_name, + enabled=1 if var.enabled else 0, + description=var.description, ) ) return variables @@ -258,3 +267,128 @@ async def async_get_variables( return await self._get_models( key, scope, scope_key, sys_code, user_name, "text2vec" ) + + +class BuiltinDatasourceVariablesProvider(BuiltinVariablesProvider): + """Builtin datasource variables provider. + + Provide all datasource variables by variables "${dbgpt.core.datasource}" + """ + + name = BUILTIN_VARIABLES_CORE_DATASOURCES + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.serve.datasource.service.service import ( + DatasourceServeResponse, + Service, + ) + + all_datasource: List[DatasourceServeResponse] = Service.get_instance( + self.system_app + ).list() + + variables = [] + for datasource in all_datasource: + label = f"[{datasource.db_type}]{datasource.db_name}" + variables.append( + StorageVariables( + key=key, + name=datasource.db_name, + label=label, + value=datasource.db_name, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + description=datasource.comment, + ) + ) + return variables + + +class BuiltinAgentsVariablesProvider(BuiltinVariablesProvider): + """Builtin agents variables provider. 
+ + Provide all agents variables by variables "${dbgpt.core.agent.agents}" + """ + + name = BUILTIN_VARIABLES_CORE_AGENTS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.agent.core.agent_manage import get_agent_manager + + agent_manager = get_agent_manager(self.system_app) + agents = agent_manager.list_agents() + variables = [] + for agent in agents: + variables.append( + StorageVariables( + key=key, + name=agent["name"], + label=agent["desc"], + value=agent["name"], + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + description=agent["desc"], + ) + ) + return variables + + +class BuiltinKnowledgeSpacesVariablesProvider(BuiltinVariablesProvider): + """Builtin knowledge variables provider. + + Provide all knowledge variables by variables "${dbgpt.core.knowledge_spaces}" + """ + + name = BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.serve.rag.service.service import Service, SpaceServeRequest + + # TODO: Query with user_name and sys_code + knowledge_list = Service.get_instance(self.system_app).get_list( + SpaceServeRequest() + ) + variables = [] + for k in knowledge_list: + variables.append( + StorageVariables( + key=key, + name=k.name, + label=k.name, + value=k.name, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + description=k.desc, + ) + ) + return variables diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index cc9c00341..30a8a06c0 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -230,7 +230,7 @@ def load_dag_from_dbgpts(self, is_first_load: bool = False): continue # Set state to DEPLOYED flow.state = State.DEPLOYED - exist_inst = self.get({"name": flow.name}) + exist_inst = self.dao.get_one({"name": flow.name}) if not exist_inst: self.create_and_save_dag(flow, save_failed_flow=True) elif is_first_load or exist_inst.state != State.RUNNING: diff --git a/dbgpt/serve/flow/service/variables_service.py b/dbgpt/serve/flow/service/variables_service.py index 09e2a16b0..fbb4cc9b9 100644 --- a/dbgpt/serve/flow/service/variables_service.py +++ b/dbgpt/serve/flow/service/variables_service.py @@ -1,10 +1,25 @@ from typing import List, Optional from dbgpt import SystemApp -from dbgpt.core.interface.variables import StorageVariables, VariablesProvider -from dbgpt.serve.core import BaseService +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_AGENTS, + BUILTIN_VARIABLES_CORE_DATASOURCES, + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_RERANKERS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, + StorageVariables, + VariablesProvider, +) +from dbgpt.serve.core import BaseService, blocking_func_to_async +from dbgpt.util import PaginationResult +from dbgpt.util.i18n_utils import _ -from ..api.schemas import VariablesRequest, VariablesResponse +from ..api.schemas import VariablesKeyResponse, VariablesRequest, 
VariablesResponse from ..config import ( SERVE_CONFIG_KEY_PREFIX, SERVE_VARIABLES_SERVICE_COMPONENT_NAME, @@ -12,6 +27,93 @@ ) from ..models.models import VariablesDao, VariablesEntity +BUILTIN_VARIABLES = [ + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_FLOWS, + label=_("All AWEL Flows"), + description=_("Fetch all AWEL flows in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_FLOW_NODES, + label=_("All AWEL Flow Nodes"), + description=_("Fetch all AWEL flow nodes in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_VARIABLES, + label=_("All Variables"), + description=_("Fetch all variables in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_SECRETS, + label=_("All Secrets"), + description=_("Fetch all secrets in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_LLMS, + label=_("All LLMs"), + description=_("Fetch all LLMs in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_EMBEDDINGS, + label=_("All Embeddings"), + description=_("Fetch all embeddings models in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_RERANKERS, + label=_("All Rerankers"), + description=_("Fetch all rerankers in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_DATASOURCES, + label=_("All Data Sources"), + description=_("Fetch all data sources in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_AGENTS, + label=_("All Agents"), + description=_("Fetch all agents in the system"), + value_type="str", + category="common", + scope="global", + ), + VariablesKeyResponse( + key=BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES, + label=_("All Knowledge Spaces"), + description=_("Fetch all knowledge spaces in the system"), + value_type="str", + category="common", + scope="global", + ), +] + + +def _is_builtin_variable(key: str) -> bool: + return key in [v.key for v in BUILTIN_VARIABLES] + class VariablesService( BaseService[VariablesEntity, VariablesRequest, VariablesResponse] @@ -148,5 +250,119 @@ def update(self, _: int, request: VariablesRequest) -> VariablesResponse: return self.dao.get_one(query) def list_all_variables(self, category: str = "common") -> List[VariablesResponse]: - """List all variables.""" + """List all variables. + + Please note that this method will return all variables in the system, it may + be a large list. 
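`list_keys` further below merges these builtin keys with de-duplicated database keys, keeping the first row seen per key. The de-duplication reduces to this sketch:

```python
from typing import Dict

# Rows as (key, label) pairs in query order; the first row per key wins.
db_variables = [
    ("openai_api_key", "OpenAI API Key"),
    ("openai_api_key", "OpenAI API Key (team scope)"),
    ("proxy_server_url", "Proxy Server URL"),
]

key_to_variable: Dict[str, str] = {}
for key, label in db_variables:
    key_to_variable.setdefault(key, label)

assert list(key_to_variable) == ["openai_api_key", "proxy_server_url"]
```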
+ """ return self.dao.get_list({"enabled": True, "category": category}) + + def list_keys( + self, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + category: Optional[str] = None, + ) -> List[VariablesKeyResponse]: + """List all keys.""" + results = [] + + # TODO: More high performance way to get the keys + all_db_variables = self.dao.get_list( + { + "enabled": True, + "category": category, + "user_name": user_name, + "sys_code": sys_code, + } + ) + if not user_name: + # Only return the keys that are not user specific + all_db_variables = [v for v in all_db_variables if not v.user_name] + if not sys_code: + # Only return the keys that are not system specific + all_db_variables = [v for v in all_db_variables if not v.sys_code] + key_to_db_variable = {} + for db_variable in all_db_variables: + key = db_variable.key + if key not in key_to_db_variable: + key_to_db_variable[key] = db_variable + + # Append all builtin variables to the results + results.extend(BUILTIN_VARIABLES) + + # Append all db variables to the results + for key, db_variable in key_to_db_variable.items(): + results.append( + VariablesKeyResponse( + key=key, + label=db_variable.label, + description=db_variable.description, + value_type=db_variable.value_type, + category=db_variable.category, + scope=db_variable.scope, + scope_key=db_variable.scope_key, + ) + ) + return results + + async def get_list_by_page( + self, + key: str, + scope: Optional[str] = None, + scope_key: Optional[str] = None, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + page: int = 1, + page_size: int = 20, + ) -> PaginationResult[VariablesResponse]: + """Get a list of variables by page.""" + if not _is_builtin_variable(key): + query = { + "key": key, + "scope": scope, + "scope_key": scope_key, + "user_name": user_name, + "sys_code": sys_code, + } + return await blocking_func_to_async( + self._system_app, + self.dao.get_list_page, + query, + page, + page_size, + desc_order_column="gmt_modified", + ) + else: + variables: List[ + StorageVariables + ] = await self.variables_provider.async_get_variables( + key=key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + result_variables = [] + for entity in variables: + result_variables.append( + VariablesResponse( + id=-1, + key=entity.key, + name=entity.name, + label=entity.label, + value=entity.value, + value_type=entity.value_type, + category=entity.category, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=True if entity.enabled == 1 else False, + user_name=entity.user_name, + sys_code=entity.sys_code, + description=entity.description, + ) + ) + return PaginationResult.build_from_all( + result_variables, + page, + page_size, + ) diff --git a/dbgpt/util/pagination_utils.py b/dbgpt/util/pagination_utils.py index f8c20ccd9..5b67333c6 100644 --- a/dbgpt/util/pagination_utils.py +++ b/dbgpt/util/pagination_utils.py @@ -15,3 +15,29 @@ class PaginationResult(BaseModel, Generic[T]): total_pages: int = Field(..., description="total number of pages") page: int = Field(..., description="Current page number") page_size: int = Field(..., description="Number of items per page") + + @classmethod + def build_from_all( + cls, all_items: List[T], page: int, page_size: int + ) -> "PaginationResult[T]": + """Build a pagination result from all items""" + if page < 1: + page = 1 + if page_size < 1: + page_size = 1 + total_count = len(all_items) + total_pages = ( + (total_count + page_size - 1) // page_size if total_count > 0 else 0 + ) + 
page = max(1, min(page, total_pages)) if total_pages > 0 else 0 + start_index = (page - 1) * page_size if page > 0 else 0 + end_index = min(start_index + page_size, total_count) + items = all_items[start_index:end_index] + + return cls( + items=items, + total_count=total_count, + total_pages=total_pages, + page=page, + page_size=page_size, + ) diff --git a/dbgpt/util/tests/test_pagination_utils.py b/dbgpt/util/tests/test_pagination_utils.py new file mode 100644 index 000000000..d0d2132c5 --- /dev/null +++ b/dbgpt/util/tests/test_pagination_utils.py @@ -0,0 +1,84 @@ +from dbgpt.util.pagination_utils import PaginationResult + + +def test_build_from_all_normal_case(): + items = list(range(100)) + result = PaginationResult.build_from_all(items, page=2, page_size=20) + + assert len(result.items) == 20 + assert result.items == list(range(20, 40)) + assert result.total_count == 100 + assert result.total_pages == 5 + assert result.page == 2 + assert result.page_size == 20 + + +def test_build_from_all_empty_list(): + items = [] + result = PaginationResult.build_from_all(items, page=1, page_size=5) + + assert result.items == [] + assert result.total_count == 0 + assert result.total_pages == 0 + assert result.page == 0 + assert result.page_size == 5 + + +def test_build_from_all_last_page(): + items = list(range(95)) + result = PaginationResult.build_from_all(items, page=5, page_size=20) + + assert len(result.items) == 15 + assert result.items == list(range(80, 95)) + assert result.total_count == 95 + assert result.total_pages == 5 + assert result.page == 5 + assert result.page_size == 20 + + +def test_build_from_all_page_out_of_range(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=10, page_size=10) + + assert len(result.items) == 10 + assert result.items == list(range(40, 50)) + assert result.total_count == 50 + assert result.total_pages == 5 + assert result.page == 5 + assert result.page_size == 10 + + +def test_build_from_all_page_zero(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=0, page_size=10) + + assert len(result.items) == 10 + assert result.items == list(range(0, 10)) + assert result.total_count == 50 + assert result.total_pages == 5 + assert result.page == 1 + assert result.page_size == 10 + + +def test_build_from_all_negative_page(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=-1, page_size=10) + + assert len(result.items) == 10 + assert result.items == list(range(0, 10)) + assert result.total_count == 50 + assert result.total_pages == 5 + assert result.page == 1 + assert result.page_size == 10 + + +def test_build_from_all_page_size_larger_than_total(): + items = list(range(50)) + result = PaginationResult.build_from_all(items, page=1, page_size=100) + + assert len(result.items) == 50 + assert result.items == list(range(50)) + assert result.total_count == 50 + assert result.total_pages == 1 + assert result.page == 1 + assert result.page_size == 100 From 0219f5733b4b648ac1f5aeec6244429e982a2682 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 29 Aug 2024 23:07:51 +0800 Subject: [PATCH 12/60] feat: Support query file metadatas --- dbgpt/serve/file/api/endpoints.py | 77 ++++++++++++++++++++++++++++- dbgpt/serve/file/api/schemas.py | 48 +++++++++++++++++- dbgpt/serve/file/service/service.py | 39 ++++++++++++++- 3 files changed, 159 insertions(+), 5 deletions(-) diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py index 26bbb9673..d5b65bc54 100644 
--- a/dbgpt/serve/file/api/endpoints.py +++ b/dbgpt/serve/file/api/endpoints.py @@ -1,3 +1,4 @@ +import asyncio import logging from functools import cache from typing import List, Optional @@ -13,7 +14,13 @@ from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from .schemas import ServeRequest, ServerResponse, UploadFileResponse +from .schemas import ( + FileMetadataBatchRequest, + FileMetadataResponse, + ServeRequest, + ServerResponse, + UploadFileResponse, +) router = APIRouter() logger = logging.getLogger(__name__) @@ -162,6 +169,74 @@ async def delete_file( return Result.succ(None) +@router.get( + "/files/metadata", + response_model=Result[FileMetadataResponse], + dependencies=[Depends(check_api_key)], +) +async def get_file_metadata( + uri: Optional[str] = Query(None, description="File URI"), + bucket: Optional[str] = Query(None, description="Bucket name"), + file_id: Optional[str] = Query(None, description="File ID"), + service: Service = Depends(get_service), +) -> Result[FileMetadataResponse]: + """Get file metadata by URI or by bucket and file_id.""" + if not uri and not (bucket and file_id): + raise HTTPException( + status_code=400, + detail="Either uri or (bucket and file_id) must be provided", + ) + + metadata = await blocking_func_to_async( + global_system_app, service.get_file_metadata, uri, bucket, file_id + ) + return Result.succ(metadata) + + +@router.post( + "/files/metadata/batch", + response_model=Result[List[FileMetadataResponse]], + dependencies=[Depends(check_api_key)], +) +async def get_files_metadata_batch( + request: FileMetadataBatchRequest, service: Service = Depends(get_service) +) -> Result[List[FileMetadataResponse]]: + """Get metadata for multiple files by URIs or bucket and file_id pairs.""" + if not request.uris and not request.bucket_file_pairs: + raise HTTPException( + status_code=400, + detail="Either uris or bucket_file_pairs must be provided", + ) + + batch_req = [] + if request.uris: + for uri in request.uris: + batch_req.append((uri, None, None)) + elif request.bucket_file_pairs: + for pair in request.bucket_file_pairs: + batch_req.append((None, pair.bucket, pair.file_id)) + else: + raise HTTPException( + status_code=400, + detail="Either uris or bucket_file_pairs must be provided", + ) + + batch_req_tasks = [ + blocking_func_to_async( + global_system_app, service.get_file_metadata, uri, bucket, file_id + ) + for uri, bucket, file_id in batch_req + ] + + metadata_list = await asyncio.gather(*batch_req_tasks) + if not metadata_list: + raise HTTPException( + status_code=404, + detail="File metadata not found", + ) + return Result.succ(metadata_list) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" global global_system_app diff --git a/dbgpt/serve/file/api/schemas.py b/dbgpt/serve/file/api/schemas.py index 911f71db3..bd8b3bbf2 100644 --- a/dbgpt/serve/file/api/schemas.py +++ b/dbgpt/serve/file/api/schemas.py @@ -1,7 +1,13 @@ # Define your Pydantic schemas here -from typing import Any, Dict +from typing import Any, Dict, List, Optional -from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict +from dbgpt._private.pydantic import ( + BaseModel, + ConfigDict, + Field, + model_to_dict, + model_validator, +) from ..config import SERVE_APP_NAME_HUMP @@ -41,3 +47,41 @@ class UploadFileResponse(BaseModel): def to_dict(self, **kwargs) -> Dict[str, Any]: """Convert the model to a dictionary""" return model_to_dict(self, **kwargs) 
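The batch endpoint above fans out one blocking metadata lookup per requested file and awaits them together. A standard-library sketch of that fan-out, with a stand-in for `service.get_file_metadata`:

```python
import asyncio
from typing import List, Optional, Tuple


def lookup(uri: Optional[str], bucket: Optional[str], file_id: Optional[str]) -> dict:
    # Stand-in for the blocking service.get_file_metadata call.
    return {"uri": uri, "bucket": bucket, "file_id": file_id}


async def fetch_all(
    batch_req: List[Tuple[Optional[str], Optional[str], Optional[str]]]
) -> List[dict]:
    loop = asyncio.get_running_loop()
    # One executor task per file, awaited concurrently, as in the endpoint.
    futures = [
        loop.run_in_executor(None, lookup, uri, bucket, file_id)
        for uri, bucket, file_id in batch_req
    ]
    return await asyncio.gather(*futures)


print(asyncio.run(fetch_all([("file-uri-1", None, None), (None, "bucket-a", "id-1")])))
```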
+ + +class _BucketFilePair(BaseModel): + """Bucket file pair model""" + + bucket: str = Field(..., title="The bucket of the file") + file_id: str = Field(..., title="The ID of the file") + + +class FileMetadataBatchRequest(BaseModel): + """File metadata batch request model""" + + uris: Optional[List[str]] = Field(None, title="The URIs of the files") + bucket_file_pairs: Optional[List[_BucketFilePair]] = Field( + None, title="The bucket file pairs" + ) + + @model_validator(mode="after") + def check_uris_or_bucket_file_pairs(self): + # Check if either uris or bucket_file_pairs is provided + if not (self.uris or self.bucket_file_pairs): + raise ValueError("Either uris or bucket_file_pairs must be provided") + # Check only one of uris or bucket_file_pairs is provided + if self.uris and self.bucket_file_pairs: + raise ValueError("Only one of uris or bucket_file_pairs can be provided") + return self + + +class FileMetadataResponse(BaseModel): + """File metadata model""" + + file_name: str = Field(..., title="The name of the file") + file_id: str = Field(..., title="The ID of the file") + bucket: str = Field(..., title="The bucket of the file") + uri: str = Field(..., title="The URI of the file") + file_size: int = Field(..., title="The size of the file") + user_name: Optional[str] = Field(None, title="The user name") + sys_code: Optional[str] = Field(None, title="The system code") diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py index 13e8b6225..85940ed35 100644 --- a/dbgpt/serve/file/service/service.py +++ b/dbgpt/serve/file/service/service.py @@ -1,7 +1,7 @@ import logging from typing import BinaryIO, List, Optional, Tuple -from fastapi import UploadFile +from fastapi import HTTPException, UploadFile from dbgpt.component import BaseComponent, SystemApp from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI @@ -10,7 +10,12 @@ from dbgpt.util.pagination_utils import PaginationResult from dbgpt.util.tracer import root_tracer, trace -from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse +from ..api.schemas import ( + FileMetadataResponse, + ServeRequest, + ServerResponse, + UploadFileResponse, +) from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..models.models import ServeDao, ServeEntity @@ -117,3 +122,33 @@ def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetada def delete_file(self, bucket: str, file_id: str) -> None: """Delete a file by file_id.""" self.file_storage_client.delete_file_by_id(bucket, file_id) + + def get_file_metadata( + self, + uri: Optional[str] = None, + bucket: Optional[str] = None, + file_id: Optional[str] = None, + ) -> Optional[FileMetadataResponse]: + """Get the metadata of a file by file_id.""" + if uri: + parsed_uri = FileStorageURI.parse(uri) + bucket, file_id = parsed_uri.bucket, parsed_uri.file_id + if not (bucket and file_id): + raise ValueError("Either uri or bucket and file_id must be provided.") + metadata = self.file_storage_client.storage_system.get_file_metadata( + bucket, file_id + ) + if not metadata: + raise HTTPException( + status_code=404, + detail=f"File metadata not found: bucket={bucket}, file_id={file_id}, uri={uri}", + ) + return FileMetadataResponse( + file_name=metadata.file_name, + file_id=metadata.file_id, + bucket=metadata.bucket, + uri=metadata.uri, + file_size=metadata.file_size, + user_name=metadata.user_name, + sys_code=metadata.sys_code, + ) From 
e9155f0c313d73b2be5dcaff0509a3d2c1004745 Mon Sep 17 00:00:00 2001
From: Fangyin Cheng
Date: Fri, 30 Aug 2024 07:24:22 +0800
Subject: [PATCH 13/60] feat: Support dynamic parameters

---
 dbgpt/core/awel/flow/base.py             | 107 +++++++++++-----
 examples/awel/awel_flow_ui_components.py | 155 ++++++++++++++++++++++-
 2 files changed, 228 insertions(+), 34 deletions(-)

diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py
index 4e691ed08..99aa77c8b 100644
--- a/dbgpt/core/awel/flow/base.py
+++ b/dbgpt/core/awel/flow/base.py
@@ -36,6 +36,8 @@
 }
 
 _BASIC_TYPES = [str, int, float, bool, dict, list, set]
+_DYNAMIC_PARAMETER_TYPES = [str, int, float, bool]
+DefaultParameterType = Union[str, int, float, bool, None]
 
 T = TypeVar("T", bound="ViewMixin")
 TM = TypeVar("TM", bound="TypeMetadata")
@@ -292,9 +294,6 @@ def get_category(cls, value: Type[Any]) -> "ParameterCategory":
         return cls.RESOURCER
 
 
-DefaultParameterType = Union[str, int, float, bool, None]
-
-
 class TypeMetadata(BaseModel):
     """The metadata of the type."""
 
@@ -313,7 +312,23 @@ def new(self: TM) -> TM:
         return self.__class__(**self.model_dump(exclude_defaults=True))
 
 
-class Parameter(TypeMetadata, Serializable):
+class BaseDynamic(BaseModel):
+    """The base dynamic field."""
+
+    dynamic: bool = Field(
+        default=False,
+        description="Whether current field is dynamic",
+        examples=[True, False],
+    )
+    dynamic_minimum: int = Field(
+        default=0,
+        description="The minimum count of the dynamic field, only valid when dynamic is"
+        " True",
+        examples=[0, 1, 2],
+    )
+
+
+class Parameter(BaseDynamic, TypeMetadata, Serializable):
     """Parameter for build operator."""
 
     label: str = Field(
@@ -332,11 +347,6 @@ class Parameter(TypeMetadata, Serializable):
         description="The category of the parameter",
         examples=["common", "resource"],
     )
-    # resource_category: Optional[str] = Field(
-    #     default=None,
-    #     description="The category of the resource, just for resource type",
-    #     examples=["llm_client", "common"],
-    # )
     resource_type: ResourceType = Field(
         default=ResourceType.INSTANCE,
         description="The type of the resource, just for resource type",
@@ -389,6 +399,17 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
             values[k] = handled_v
         return values
 
+    @model_validator(mode="after")
+    def check_parameters(self) -> "Parameter":
+        """Check the parameters."""
+        if self.dynamic and not self.is_list:
+            raise FlowMetadataException("Dynamic parameter must be a list.")
+        if self.dynamic and self.dynamic_minimum < 0:
+            raise FlowMetadataException(
+                "Dynamic minimum must be greater than or equal to 0."
+ ) + return self + @classmethod def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any: def _parse_single_value(vv: Any) -> Any: @@ -450,6 +471,8 @@ def build_from( description: Optional[str] = None, options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None, resource_type: ResourceType = ResourceType.INSTANCE, + dynamic: bool = False, + dynamic_minimum: int = 0, alias: Optional[List[str]] = None, ui: Optional[UIComponent] = None, ): @@ -461,6 +484,8 @@ def build_from( raise ValueError(f"Default value is missing for optional parameter {name}.") if not optional: default = None + if dynamic and type not in _DYNAMIC_PARAMETER_TYPES: + raise ValueError("Dynamic parameter must be str, int, float or bool.") return cls( label=label, name=name, @@ -474,6 +499,8 @@ def build_from( placeholder=placeholder, description=description or label, options=options, + dynamic=dynamic, + dynamic_minimum=dynamic_minimum, alias=alias, ui=ui, ) @@ -635,6 +662,11 @@ class BaseResource(Serializable, BaseModel): description="The label to display in UI", examples=["LLM Operator", "OpenAI LLM Client"], ) + custom_label: Optional[str] = Field( + None, + description="The custom label to display in UI", + examples=["LLM Operator", "OpenAI LLM Client"], + ) name: str = Field( ..., description="The name of the operator", @@ -668,7 +700,7 @@ class IOFiledType(str, Enum): LIST = "list" -class IOField(Resource): +class IOField(BaseDynamic, Resource): """The input or output field of the operator.""" is_list: bool = Field( @@ -676,17 +708,6 @@ class IOField(Resource): description="Whether current field is list", examples=[True, False], ) - dynamic: bool = Field( - default=False, - description="Whether current field is dynamic", - examples=[True, False], - ) - dynamic_minimum: int = Field( - default=0, - description="The minimum count of the dynamic field, only valid when dynamic is" - " True", - examples=[0, 1, 2], - ) mappers: Optional[List[str]] = Field( default=None, description="The mappers of the field, transform the field to the target type", @@ -724,18 +745,6 @@ def build_from( mappers=mappers_cls, ) - @model_validator(mode="before") - @classmethod - def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Pre fill the metadata.""" - if not isinstance(values, dict): - return values - if "dynamic" not in values: - values["dynamic"] = False - if "dynamic_minimum" not in values: - values["dynamic_minimum"] = 0 - return values - class BaseMetadata(BaseResource): """The base metadata.""" @@ -1137,6 +1146,38 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["outputs"] = new_outputs return values + @model_validator(mode="after") + def check_metadata(self) -> "ViewMetadata": + """Check the metadata.""" + if self.inputs: + for field in self.inputs: + if field.mappers: + raise ValueError("Input field can't have mappers.") + dyn_cnt, is_last_field_dynamic = 0, False + for field in self.inputs: + if field.dynamic: + dyn_cnt += 1 + is_last_field_dynamic = True + else: + if is_last_field_dynamic: + raise ValueError("Dynamic field input must be the last field.") + is_last_field_dynamic = False + if dyn_cnt > 1: + raise ValueError("Only one dynamic input field is allowed.") + if self.outputs: + dyn_cnt, is_last_field_dynamic = 0, False + for field in self.outputs: + if field.dynamic: + dyn_cnt += 1 + is_last_field_dynamic = True + else: + if is_last_field_dynamic: + raise ValueError("Dynamic field output must be the last field.") + is_last_field_dynamic = False + if 
dyn_cnt > 1: + raise ValueError("Only one dynamic output field is allowed.") + return self + def get_operator_key(self) -> str: """Get the operator key.""" if not self.flow_type: diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 2db9607ea..ce411a79d 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, List, Optional -from dbgpt.core.awel import MapOperator +from dbgpt.core.awel import JoinOperator, MapOperator from dbgpt.core.awel.flow import ( FunctionDynamicOptions, IOField, @@ -1243,3 +1243,156 @@ def execute_code_blocks(self, code_blocks): if exitcode != 0: return exitcode, logs_all return exitcode, logs_all + + +class ExampleFlowDynamicParametersOperator(MapOperator[str, str]): + """An example flow operator that includes dynamic parameters.""" + + metadata = ViewMetadata( + label="Example Dynamic Parameters Operator", + name="example_dynamic_parameters_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes dynamic parameters.", + parameters=[ + Parameter.build_from( + "Dynamic String", + "dynamic_1", + type=str, + is_list=True, + placeholder="Please input the dynamic parameter", + description="The dynamic parameter you want to use, you can add more, " + "at least 1 parameter.", + dynamic=True, + dynamic_minimum=1, + ui=ui.UIInput(), + ), + Parameter.build_from( + "Dynamic Integer", + "dynamic_2", + type=int, + is_list=True, + placeholder="Please input the dynamic parameter", + description="The dynamic parameter you want to use, you can add more, " + "at least 0 parameter.", + dynamic=True, + dynamic_minimum=0, + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Dynamic", + "dynamic", + str, + description="User's selected dynamic.", + ), + ], + ) + + def __init__(self, dynamic_1: List[str], dynamic_2: List[int], **kwargs): + super().__init__(**kwargs) + if not dynamic_1: + raise ValueError("The dynamic string is empty.") + self.dynamic_1 = dynamic_1 + self.dynamic_2 = dynamic_2 + + async def map(self, user_name: str) -> str: + """Map the user name to the dynamic.""" + return "Your name is %s, and your dynamic is %s." % ( + user_name, + f"dynamic_1: {self.dynamic_1}, dynamic_2: {self.dynamic_2}", + ) + + +class ExampleFlowDynamicOutputsOperator(MapOperator[str, str]): + """An example flow operator that includes dynamic outputs.""" + + metadata = ViewMetadata( + label="Example Dynamic Outputs Operator", + name="example_dynamic_outputs_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes dynamic outputs.", + parameters=[], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Dynamic", + "dynamic", + str, + description="User's selected dynamic.", + dynamic=True, + dynamic_minimum=1, + ), + ], + ) + + async def map(self, user_name: str) -> str: + """Map the user name to the dynamic.""" + return "Your name is %s, this operator has dynamic outputs." 
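The validators above pin down the contract for dynamic fields: a dynamic parameter must be a list of a basic type (str, int, float, bool), and at most one dynamic input or output is allowed per operator, which must come last. A minimal sketch of the happy path and the failure modes, mirroring ExampleFlowDynamicParametersOperator above:

```python
# Sketch: the dynamic-parameter contract added in this patch.
from dbgpt.core.awel.flow import Parameter

# Valid: a dynamic parameter is a list of a basic type, here requiring
# at least one supplied value.
dyn = Parameter.build_from(
    "Dynamic String",
    "dynamic_1",
    type=str,
    is_list=True,
    dynamic=True,
    dynamic_minimum=1,
)

# Invalid: dynamic parameters are restricted to str/int/float/bool;
# build_from() rejects anything else before the model is constructed.
try:
    Parameter.build_from("Bad", "bad", type=dict, is_list=True, dynamic=True)
except ValueError as e:
    print(e)  # "Dynamic parameter must be str, int, float or bool."

# Also invalid: dynamic=True with is_list=False trips the model
# validator and raises FlowMetadataException("Dynamic parameter must
# be list.").
```
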
% user_name + + +class ExampleFlowDynamicInputsOperator(JoinOperator[str]): + """An example flow operator that includes dynamic inputs.""" + + metadata = ViewMetadata( + label="Example Dynamic Inputs Operator", + name="example_dynamic_inputs_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes dynamic inputs.", + parameters=[], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + IOField.build_from( + "Other Inputs", + "other_inputs", + str, + description="Other inputs.", + dynamic=True, + dynamic_minimum=0, + ), + ], + outputs=[ + IOField.build_from( + "Dynamic", + "dynamic", + str, + description="User's selected dynamic.", + ), + ], + ) + + def __init__(self, **kwargs): + super().__init__(combine_function=self.join, **kwargs) + + async def join(self, user_name: str, *other_inputs: str) -> str: + """Map the user name to the dynamic.""" + if not other_inputs: + dyn_inputs = ["You have no other inputs."] + else: + dyn_inputs = [ + f"Input {i}: {input_data}" for i, input_data in enumerate(other_inputs) + ] + dyn_str = "\n".join(dyn_inputs) + return "Your name is %s, and your dynamic is %s." % ( + user_name, + f"other_inputs:\n{dyn_str}", + ) From d15201780086531715194bab8d3c69ee26e10d43 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 30 Aug 2024 11:27:02 +0800 Subject: [PATCH 14/60] feat: Support rag flow template --- dbgpt/serve/dbgpts/__init__.py | 0 dbgpt/serve/flow/api/endpoints.py | 29 +- dbgpt/serve/flow/service/service.py | 59 + .../en/rag-chat-awel-flow-template.json | 1088 +++++++++++++++++ 4 files changed, 1174 insertions(+), 2 deletions(-) create mode 100644 dbgpt/serve/dbgpts/__init__.py create mode 100644 dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json diff --git a/dbgpt/serve/dbgpts/__init__.py b/dbgpt/serve/dbgpts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index a985b9ab1..5d3ebd75d 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -14,7 +14,7 @@ from dbgpt.util import PaginationResult from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig -from ..service.service import Service +from ..service.service import Service, _parse_flow_template_from_json from ..service.variables_service import VariablesService from .schemas import ( FlowDebugRequest, @@ -517,7 +517,7 @@ async def import_flow( raise HTTPException( status_code=400, detail="invalid json file, missing 'flow' key" ) - flow = ServeRequest.parse_obj(json_dict["flow"]) + flow = _parse_flow_template_from_json(json_dict["flow"]) elif file_extension == "zip": from ..service.share_utils import _parse_flow_from_zip_file @@ -536,6 +536,31 @@ async def import_flow( return Result.succ(flow) +@router.get( + "/flow/templates", + response_model=Result[PaginationResult[ServerResponse]], + dependencies=[Depends(check_api_key)], +) +async def query_flow_templates( + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + page: int = Query(default=1, description="current page"), + page_size: int = Query(default=20, description="page size"), + service: Service = Depends(get_service), +) -> Result[PaginationResult[ServerResponse]]: + """Query Flow templates.""" + + res = await blocking_func_to_async( + global_system_app, + service.get_flow_templates, + 
user_name, + sys_code, + page, + page_size, + ) + return Result.succ(res) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" from .variables_provider import ( diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 30a8a06c0..54d07d49a 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -1,5 +1,6 @@ import json import logging +import os from typing import AsyncIterator, List, Optional, cast import schedule @@ -399,6 +400,47 @@ def get_list_by_page( item.metadata = metadata.to_dict() return page_result + def get_flow_templates( + self, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + page: int = 1, + page_size: int = 20, + ) -> PaginationResult[ServerResponse]: + """Get a list of Flow templates + + Args: + user_name (Optional[str]): The user name + sys_code (Optional[str]): The system code + page (int): The page number + page_size (int): The page size + Returns: + List[ServerResponse]: The response + """ + local_file_templates = self._get_flow_templates_from_files() + return PaginationResult.build_from_all(local_file_templates, page, page_size) + + def _get_flow_templates_from_files(self) -> List[ServerResponse]: + """Get a list of Flow templates from files""" + user_lang = self._system_app.config.get_current_lang(default="en") + # List files in current directory + parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + template_dir = os.path.join(parent_dir, "templates", user_lang) + default_template_dir = os.path.join(parent_dir, "templates", "en") + if not os.path.exists(template_dir): + template_dir = default_template_dir + templates = [] + for root, _, files in os.walk(template_dir): + for file in files: + if file.endswith(".json"): + try: + with open(os.path.join(root, file), "r") as f: + data = json.load(f) + templates.append(_parse_flow_template_from_json(data)) + except Exception as e: + logger.warning(f"Load template {file} error: {str(e)}") + return templates + async def chat_stream_flow_str( self, flow_uid: str, request: CommonLLMHttpRequestBody ) -> AsyncIterator[str]: @@ -638,3 +680,20 @@ async def _wrapper_chat_stream_flow_str( break else: yield f"data:{text}\n\n" + + +def _parse_flow_template_from_json(json_dict: dict) -> ServerResponse: + """Parse the flow from json + + Args: + json_dict (dict): The json dict + + Returns: + ServerResponse: The flow + """ + flow_json = json_dict["flow"] + flow_json["editable"] = False + del flow_json["uid"] + flow_json["state"] = State.INITIALIZING + flow_json["dag_id"] = None + return ServerResponse(**flow_json) diff --git a/dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json b/dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json new file mode 100644 index 000000000..60ff5c911 --- /dev/null +++ b/dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json @@ -0,0 +1,1088 @@ +{ + "flow": { + "uid": "21eb87d5-b63a-4f41-b2aa-28d01033344d", + "label": "RAG Chat AWEL flow template", + "name": "rag_chat_awel_flow_template", + "flow_category": "chat_flow", + "description": "An example of a RAG chat AWEL flow.", + "state": "running", + "error_message": "", + "source": "DBGPT-WEB", + "source_url": null, + "version": "0.1.1", + "define_type": "json", + "editable": true, + "user_name": null, + "sys_code": null, + "dag_id": "flow_dag_rag_chat_awel_flow_template_21eb87d5-b63a-4f41-b2aa-28d01033344d", + "gmt_created": "2024-08-30 10:48:56", + "gmt_modified": "2024-08-30 
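The new `/flow/templates` route simply pages over templates loaded from the per-language `templates/` directory, falling back to `en`. A hedged client-side sketch; the mount prefix, port, and API key below are assumptions for a default local deployment, not part of this patch:

```python
# Sketch: listing flow templates over HTTP. BASE_URL assumes the flow
# serve app's default mount prefix; adjust for your deployment. The
# Authorization header is only needed when API keys are configured.
import requests

BASE_URL = "http://127.0.0.1:5670/api/v2/serve/awel"  # assumed prefix
headers = {"Authorization": "Bearer EMPTY"}  # hypothetical API key

resp = requests.get(
    f"{BASE_URL}/flow/templates",
    params={"page": 1, "page_size": 20},
    headers=headers,
    timeout=10,
)
resp.raise_for_status()
page = resp.json()["data"]
for tpl in page["items"]:
    # Each item was normalized by _parse_flow_template_from_json:
    # editable=False, state reset to INITIALIZING, dag_id cleared.
    print(tpl["name"], tpl["label"])
```
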
10:48:56", + "metadata": { + "sse_output": true, + "streaming_output": true, + "tags": {}, + "triggers": [ + { + "trigger_type": "http", + "path": "/api/v1/awel/trigger/templates/flow_dag_rag_chat_awel_flow_template_21eb87d5-b63a-4f41-b2aa-28d01033344d", + "methods": [ + "POST" + ], + "trigger_mode": "chat" + } + ] + }, + "variables": null, + "authors": null, + "flow_data": { + "edges": [ + { + "source": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + "source_order": 0, + "target": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "target_order": 0, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_handle": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|outputs|0", + "target_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|inputs|0", + "type": "buttonedge" + }, + { + "source": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + "source_order": 1, + "target": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "target_order": 0, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "source_handle": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0|outputs|1", + "target_handle": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0|inputs|0", + "type": "buttonedge" + }, + { + "source": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "source_order": 0, + "target": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "target_order": 1, + "id": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0|operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_handle": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0|outputs|0", + "target_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|inputs|1", + "type": "buttonedge" + }, + { + "source": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "source_order": 0, + "target": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "target_order": 0, + "id": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0|operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_handle": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0|outputs|0", + "target_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|parameters|0", + "type": "buttonedge" + }, + { + "source": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "source_order": 0, + "target": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "target_order": 0, + "id": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "source_handle": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0|outputs|0", + "target_handle": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0|inputs|0", + "type": "buttonedge" + } + ], + "viewport": { + "x": 900.5986504747431, + "y": 420.90015979869725, + "zoom": 0.6903331247004052 + }, + "nodes": [ + { + "width": 320, + "height": 632, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + 
"position": { + "x": -1164.0000230376968, + "y": -501.9869760888273, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": -1164.0000230376968, + "y": -501.9869760888273, + "zoom": 0.0 + }, + "data": { + "label": "Common LLM Http Trigger", + "custom_label": null, + "name": "common_llm_http_trigger", + "description": "Trigger your workflow by http request, and parse the request body as a common LLM http body", + "category": "trigger", + "category_label": "Trigger", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_common_llm_http_trigger___$$___trigger___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "input", + "inputs": [], + "outputs": [ + { + "type_name": "CommonLLMHttpRequestBody", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpRequestBody", + "label": "Request Body", + "custom_label": null, + "name": "request_body", + "description": "The request body of the API endpoint, parse as a common LLM http body", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "label": "Request String Messages", + "custom_label": null, + "name": "request_string_messages", + "description": "The request string messages of the API endpoint, parsed from 'messages' field of the request body", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": [ + "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpTrigger.MessagesOutputMapper" + ] + } + ], + "version": "v1", + "type_name": "CommonLLMHttpTrigger", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpTrigger", + "parameters": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "API Endpoint", + "name": "endpoint", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "/example/{dag_id}", + "placeholder": null, + "description": "The API endpoint", + "value": "/templates/{dag_id}", + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Http Methods", + "name": "methods", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "POST", + "placeholder": null, + "description": "The methods of the API endpoint", + "value": null, + "options": [ + { + "label": "HTTP Method PUT", + "name": "http_put", + "value": "PUT", + "children": null + }, + { + "label": "HTTP Method POST", + "name": "http_post", + "value": "POST", + "children": null + } + ] + }, + { + "type_name": "bool", + "type_cls": "builtins.bool", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Streaming Response", + "name": "streaming_response", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": false, + "placeholder": null, + "description": "Whether the response is streaming", + "value": false, + "options": null + }, + { + "type_name": "BaseHttpBody", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.BaseHttpBody", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Http Response Body", + "name": "http_response_body", + "is_list": false, + "category": "resource", + "resource_type": "class", + "optional": true, + "default": null, + "placeholder": null, + "description": "The response body of the API endpoint", + "value": null, + "options": null 
+ }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Response Media Type", + "name": "response_media_type", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The response media type", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Http Status Code", + "name": "status_code", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 200, + "placeholder": null, + "description": "The http status code", + "value": null, + "options": null + } + ] + } + }, + { + "width": 320, + "height": 910, + "id": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "position": { + "x": 661.094354143159, + "y": -368.93541722528227, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": 661.094354143159, + "y": -368.93541722528227, + "zoom": 0.0 + }, + "data": { + "label": "Streaming LLM Operator", + "custom_label": null, + "name": "higher_order_streaming_llm_operator", + "description": "High-level streaming LLM operator, supports multi-round conversation (conversation window, token length and no multi-round).", + "category": "llm", + "category_label": "LLM", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_higher_order_streaming_llm_operator___$$___llm___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "map", + "inputs": [ + { + "type_name": "CommonLLMHttpRequestBody", + "type_cls": "dbgpt.core.awel.trigger.http_trigger.CommonLLMHttpRequestBody", + "label": "Common LLM Request Body", + "custom_label": null, + "name": "common_llm_request_body", + "description": "The common LLM request body.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + }, + { + "type_name": "HOContextBody", + "type_cls": "dbgpt.app.operators.llm.HOContextBody", + "label": "Extra Context", + "custom_label": null, + "name": "extra_context", + "description": "Extra context for building prompt(Knowledge context, database schema, etc), you can add multiple context.", + "dynamic": true, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + } + ], + "outputs": [ + { + "type_name": "ModelOutput", + "type_cls": "dbgpt.core.interface.llm.ModelOutput", + "label": "Streaming Model Output", + "custom_label": null, + "name": "streaming_model_output", + "description": "The streaming model output.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": null + } + ], + "version": "v1", + "type_name": "HOStreamingLLMOperator", + "type_cls": "dbgpt.app.operators.llm.HOStreamingLLMOperator", + "parameters": [ + { + "type_name": "ChatPromptTemplate", + "type_cls": "dbgpt.core.interface.prompt.ChatPromptTemplate", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Prompt Template", + "name": "prompt_template", + "is_list": false, + "category": "resource", + "resource_type": "instance", + "optional": false, + "default": null, + "placeholder": null, + "description": "The prompt template for the conversation.", + "value": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Model 
Name", + "name": "model", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The model name.", + "value": null, + "options": null + }, + { + "type_name": "LLMClient", + "type_cls": "dbgpt.core.interface.llm.LLMClient", + "dynamic": false, + "dynamic_minimum": 0, + "label": "LLM Client", + "name": "llm_client", + "is_list": false, + "category": "resource", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The LLM Client, how to connect to the LLM model, if not provided, it will use the default client deployed by DB-GPT.", + "value": null, + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "History Message Merge Mode", + "name": "history_merge_mode", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "none", + "placeholder": null, + "description": "The history merge mode, supports 'none', 'window' and 'token'. 'none': no history merge, 'window': merge by conversation window, 'token': merge by token length.", + "value": "window", + "options": [ + { + "label": "No History", + "name": "none", + "value": "none", + "children": null + }, + { + "label": "Message Window", + "name": "window", + "value": "window", + "children": null + }, + { + "label": "Token Length", + "name": "token", + "value": "token", + "children": null + } + ], + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "select", + "size": null, + "attr": null + } + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "User Message Key", + "name": "user_message_key", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "user_input", + "placeholder": null, + "description": "The key of the user message in your prompt, default is 'user_input'.", + "value": null, + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "History Key", + "name": "history_key", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The chat history key, with chat history message pass to prompt template, if not provided, it will parse the prompt template to get the key.", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Keep Start Rounds", + "name": "keep_start_rounds", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The start rounds to keep in the chat history.", + "value": 0, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Keep End Rounds", + "name": "keep_end_rounds", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "The end rounds to keep in the chat history.", + "value": 10, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Max Token Limit", + "name": "max_token_limit", + 
"is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 2048, + "placeholder": null, + "description": "The max token limit to keep in the chat history.", + "value": null, + "options": null + } + ] + } + }, + { + "width": 320, + "height": 774, + "id": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "position": { + "x": -781.3390803520426, + "y": 112.87665693387501, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": -781.3390803520426, + "y": 112.87665693387501, + "zoom": 0.0 + }, + "data": { + "label": "Knowledge Operator", + "custom_label": null, + "name": "higher_order_knowledge_operator", + "description": "Knowledge Operator, retrieve your knowledge(documents) from knowledge space", + "category": "rag", + "category_label": "RAG", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_higher_order_knowledge_operator___$$___rag___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "map", + "inputs": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "label": "User question", + "custom_label": null, + "name": "query", + "description": "The user question to retrieve the knowledge", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + } + ], + "outputs": [ + { + "type_name": "HOContextBody", + "type_cls": "dbgpt.app.operators.llm.HOContextBody", + "label": "Retrieved context", + "custom_label": null, + "name": "context", + "description": "The retrieved context from the knowledge space", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": false, + "mappers": null + }, + { + "type_name": "Chunk", + "type_cls": "dbgpt.core.interface.knowledge.Chunk", + "label": "Chunks", + "custom_label": null, + "name": "chunks", + "description": "The retrieved chunks from the knowledge space", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": [ + "dbgpt.app.operators.rag.HOKnowledgeOperator.ChunkMapper" + ] + } + ], + "version": "v1", + "type_name": "HOKnowledgeOperator", + "type_cls": "dbgpt.app.operators.rag.HOKnowledgeOperator", + "parameters": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Knowledge Space Name", + "name": "knowledge_space", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": false, + "default": null, + "placeholder": null, + "description": "The name of the knowledge space", + "value": "k_cmd2", + "options": [ + { + "label": "k_cmd2", + "name": "k_cmd2", + "value": "k_cmd2", + "children": null + }, + { + "label": "f5", + "name": "f5", + "value": "f5", + "children": null + }, + { + "label": "f4", + "name": "f4", + "value": "f4", + "children": null + }, + { + "label": "t333", + "name": "t333", + "value": "t333", + "children": null + }, + { + "label": "f3", + "name": "f3", + "value": "f3", + "children": null + }, + { + "label": "f1", + "name": "f1", + "value": "f1", + "children": null + }, + { + "label": "sdf", + "name": "sdf", + "value": "sdf", + "children": null + }, + { + "label": "sfsd", + "name": "sfsd", + "value": "sfsd", + "children": null + }, + { + "label": "hello", + "name": "hello", + "value": "hello", + "children": null + }, + { + "label": "k1", + "name": "k1", + "value": "k1", + "children": null + }, + { + "label": "f2", + "name": "f2", + "value": "f2", + "children": null + }, + { + "label": "test_f1", + "name": "test_f1", + "value": 
"test_f1", + "children": null + }, + { + "label": "SMMF", + "name": "SMMF", + "value": "SMMF", + "children": null + }, + { + "label": "docker_xxx", + "name": "docker_xxx", + "value": "docker_xxx", + "children": null + }, + { + "label": "t2", + "name": "t2", + "value": "t2", + "children": null + }, + { + "label": "t1", + "name": "t1", + "value": "t1", + "children": null + }, + { + "label": "test_graph", + "name": "test_graph", + "value": "test_graph", + "children": null + }, + { + "label": "small", + "name": "small", + "value": "small", + "children": null + }, + { + "label": "ttt", + "name": "ttt", + "value": "ttt", + "children": null + }, + { + "label": "bf", + "name": "bf", + "value": "bf", + "children": null + }, + { + "label": "new_big_file", + "name": "new_big_file", + "value": "new_big_file", + "children": null + }, + { + "label": "test_big_fild", + "name": "test_big_fild", + "value": "test_big_fild", + "children": null + }, + { + "label": "Greenplum", + "name": "Greenplum", + "value": "Greenplum", + "children": null + }, + { + "label": "Mytest", + "name": "Mytest", + "value": "Mytest", + "children": null + }, + { + "label": "dba", + "name": "dba", + "value": "dba", + "children": null + } + ] + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Context Key", + "name": "context", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "context", + "placeholder": null, + "description": "The key of the context, it will be used in building the prompt", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Top K", + "name": "top_k", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 5, + "placeholder": null, + "description": "The number of chunks to retrieve", + "value": null, + "options": null + }, + { + "type_name": "float", + "type_cls": "builtins.float", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Minimum Match Score", + "name": "score_threshold", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 0.3, + "placeholder": null, + "description": "The minimum match score for the retrieved chunks, it will be dropped if the match score is less than the threshold", + "value": null, + "options": null, + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "slider", + "size": null, + "attr": { + "disabled": false, + "min": 0.0, + "max": 1.0, + "step": 0.1 + }, + "show_input": false + } + }, + { + "type_name": "bool", + "type_cls": "builtins.bool", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Reranker Enabled", + "name": "reranker_enabled", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": null, + "placeholder": null, + "description": "Whether to enable the reranker", + "value": null, + "options": null + }, + { + "type_name": "int", + "type_cls": "builtins.int", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Reranker Top K", + "name": "reranker_top_k", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": 3, + "placeholder": null, + "description": "The top k for the reranker", + "value": null, + "options": null + } + ] + } + }, + { + "width": 320, + "height": 884, + "id": 
"resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "position": { + "x": 195.5602050169747, + "y": 175.41495969060128, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": 195.5602050169747, + "y": 175.41495969060128, + "zoom": 0.0 + }, + "data": { + "type_name": "CommonChatPromptTemplate", + "type_cls": "dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate", + "label": "Common Chat Prompt Template", + "custom_label": null, + "name": "common_chat_prompt_template", + "description": "The operator to build the prompt with static prompt.", + "category": "prompt", + "category_label": "Prompt", + "flow_type": "resource", + "icon": null, + "documentation_url": null, + "id": "resource_dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0", + "ui_size": "large" + }, + "resource_type": "instance", + "parent_cls": [ + "dbgpt.core.interface.operators.prompt_operator.CommonChatPromptTemplate", + "dbgpt.core.interface.prompt.ChatPromptTemplate", + "dbgpt.core.interface.prompt.BasePromptTemplate", + "pydantic.main.BaseModel" + ], + "parameters": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "System Message", + "name": "system_message", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "You are a helpful AI Assistant.", + "placeholder": null, + "description": "The system message.", + "value": "You are a helpful AI assistant.\nBased on the known information below, provide users with professional and concise answers to their questions.\nconstraints:\n 1.Ensure to include original markdown formatting elements such as images, links, tables, or code blocks without alteration in the response if they are present in the provided information.\n For example, image format should be ![image.png](xxx), link format [xxx](xxx), table format should be represented with |xxx|xxx|xxx|, and code format with xxx.\n 2.If the information available in the knowledge base is insufficient to answer the question, state clearly: \"The content provided in the knowledge base is not enough to answer this question,\" and avoid making up answers.\n 3.When responding, it is best to summarize the points in the order of 1, 2, 3, And displayed in markdwon format.\n\nknown information: \n{context}\n\nuser question:\n{user_input}\n\nwhen answering, use the same language as the \"user\".", + "options": null, + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "text_area", + "size": "large", + "attr": { + "disabled": false, + "status": null, + "prefix": null, + "suffix": null, + "show_count": null, + "max_length": null, + "auto_size": { + "min_rows": 2, + "max_rows": 20 + } + }, + "editor": { + "width": 800, + "height": 400 + } + } + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Message placeholder", + "name": "message_placeholder", + "is_list": false, + "category": "common", + "resource_type": "instance", + "optional": true, + "default": "chat_history", + "placeholder": null, + "description": "The chat history message placeholder.", + "value": null, + "options": null + }, + { + "type_name": "str", + "type_cls": "builtins.str", + "dynamic": false, + "dynamic_minimum": 0, + "label": "Human Message", + "name": "human_message", + "is_list": false, + "category": "common", + 
"resource_type": "instance", + "optional": true, + "default": "{user_input}", + "placeholder": "{user_input}", + "description": "The human message.", + "value": null, + "options": null, + "ui": { + "refresh": false, + "refresh_depends": null, + "ui_type": "text_area", + "size": "large", + "attr": { + "disabled": false, + "status": null, + "prefix": null, + "suffix": null, + "show_count": null, + "max_length": null, + "auto_size": { + "min_rows": 2, + "max_rows": 20 + } + }, + "editor": { + "width": 800, + "height": 400 + } + } + } + ] + } + }, + { + "width": 320, + "height": 235, + "id": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "position": { + "x": 1087.8490700167088, + "y": 389.9348086323575, + "zoom": 0.0 + }, + "type": "customNode", + "position_absolute": { + "x": 1087.8490700167088, + "y": 389.9348086323575, + "zoom": 0.0 + }, + "data": { + "label": "OpenAI Streaming Output Operator", + "custom_label": null, + "name": "openai_streaming_output_operator", + "description": "The OpenAI streaming LLM operator.", + "category": "output_parser", + "category_label": "Output Parser", + "flow_type": "operator", + "icon": null, + "documentation_url": null, + "id": "operator_openai_streaming_output_operator___$$___output_parser___$$___v1_0", + "tags": { + "order": "higher-order", + "ui_version": "flow2.0" + }, + "operator_type": "transform_stream", + "inputs": [ + { + "type_name": "ModelOutput", + "type_cls": "dbgpt.core.interface.llm.ModelOutput", + "label": "Upstream Model Output", + "custom_label": null, + "name": "model_output", + "description": "The model output of upstream.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": null + } + ], + "outputs": [ + { + "type_name": "str", + "type_cls": "builtins.str", + "label": "Model Output", + "custom_label": null, + "name": "model_output", + "description": "The model output after transformed to openai stream format.", + "dynamic": false, + "dynamic_minimum": 0, + "is_list": true, + "mappers": null + } + ], + "version": "v1", + "type_name": "OpenAIStreamingOutputOperator", + "type_cls": "dbgpt.model.utils.chatgpt_utils.OpenAIStreamingOutputOperator", + "parameters": [] + } + } + ] + } + } +} \ No newline at end of file From d07cce603add42f58eb8a5c25a1dc454f9963180 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Fri, 30 Aug 2024 17:54:34 +0800 Subject: [PATCH 15/60] chore: update SaveFlowModal component to use useRouter hook and add useEffect for id state --- .../flow/canvas-modal/save-flow-modal.tsx | 23 ++++++++----- web/locales/en/flow.ts | 1 + web/locales/zh/flow.ts | 1 + web/pages/construct/flow/canvas/index.tsx | 33 +++++++++++-------- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/web/components/flow/canvas-modal/save-flow-modal.tsx b/web/components/flow/canvas-modal/save-flow-modal.tsx index ba37afcf5..43d43ac61 100644 --- a/web/components/flow/canvas-modal/save-flow-modal.tsx +++ b/web/components/flow/canvas-modal/save-flow-modal.tsx @@ -2,8 +2,8 @@ import { addFlow, apiInterceptors, updateFlowById } from '@/client/api'; import { IFlowData, IFlowUpdateParam } from '@/types/flow'; import { mapHumpToUnderline } from '@/utils/flow'; import { Button, Checkbox, Form, Input, Modal, Space, message } from 'antd'; -import { useSearchParams } from 'next/navigation'; -import { useState } from 'react'; +import { useRouter } from 'next/router'; +import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { 
ReactFlowInstance } from 'reactflow'; @@ -22,13 +22,18 @@ export const SaveFlowModal: React.FC = ({ flowInfo, setIsSaveFlowModalOpen, }) => { - const [deploy, setDeploy] = useState(true); const { t } = useTranslation(); - const searchParams = useSearchParams(); - const id = searchParams?.get('id') || ''; + const router = useRouter(); const [form] = Form.useForm(); const [messageApi, contextHolder] = message.useMessage(); + const [deploy, setDeploy] = useState(true); + const [id, setId] = useState(router.query.id || ''); + + useEffect(() => { + setId(router.query.id || ''); + }, [router.query.id]); + function onLabelChange(e: React.ChangeEvent) { const label = e.target.value; // replace spaces with underscores, convert uppercase letters to lowercase, remove characters other than digits, letters, _, and -. @@ -45,12 +50,12 @@ export const SaveFlowModal: React.FC = ({ if (id) { const [, , res] = await apiInterceptors( - updateFlowById(id, { + updateFlowById(id.toString(), { name, label, description, editable, - uid: id, + uid: id.toString(), flow_data: reactFlowObject, state, }), @@ -72,10 +77,10 @@ export const SaveFlowModal: React.FC = ({ state, }), ); + if (res?.uid) { messageApi.success(t('save_flow_success')); - const history = window.history; - history.pushState(null, '', `/flow/canvas?id=${res.uid}`); + router.push(`/construct/flow/canvas?id=${res.uid}`, undefined, { shallow: true }); } } setIsSaveFlowModalOpen(false); diff --git a/web/locales/en/flow.ts b/web/locales/en/flow.ts index 18740bf5e..fffe4a6a3 100644 --- a/web/locales/en/flow.ts +++ b/web/locales/en/flow.ts @@ -16,4 +16,5 @@ export const FlowEn = { Export_File_Format: 'File_Format', Yes: 'Yes', No: 'No', + Please_Add_Nodes_First: 'Please add nodes first', }; diff --git a/web/locales/zh/flow.ts b/web/locales/zh/flow.ts index fda54cd50..9ce76d733 100644 --- a/web/locales/zh/flow.ts +++ b/web/locales/zh/flow.ts @@ -16,4 +16,5 @@ export const FlowZn = { Export_File_Format: '文件格式', Yes: '是', No: '否', + Please_Add_Nodes_First: '请先添加节点', }; diff --git a/web/pages/construct/flow/canvas/index.tsx b/web/pages/construct/flow/canvas/index.tsx index c533290aa..5e5676c66 100644 --- a/web/pages/construct/flow/canvas/index.tsx +++ b/web/pages/construct/flow/canvas/index.tsx @@ -30,6 +30,7 @@ const edgeTypes = { buttonedge: ButtonEdge }; const Canvas: React.FC = () => { const { t } = useTranslation(); + const [messageApi, contextHolder] = message.useMessage(); const searchParams = useSearchParams(); const id = searchParams?.get('id') || ''; @@ -152,22 +153,24 @@ const Canvas: React.FC = () => { function onSave() { const flowData = reactFlow.toObject() as IFlowData; const [check, node, message] = checkFlowDataRequied(flowData); + + if (!node) { + messageApi.open({ + type: 'warning', + content: t('Please_Add_Nodes_First'), + }); + return; + } + if (!check && message) { setNodes(nds => - nds.map(item => { - if (item.id === node?.id) { - item.data = { - ...item.data, - invalid: true, - }; - } else { - item.data = { - ...item.data, - invalid: false, - }; - } - return item; - }), + nds.map(item => ({ + ...item, + data: { + ...item.data, + invalid: item.id === node?.id, + }, + })), ); return notification.error({ message: 'Error', @@ -274,6 +277,8 @@ const Canvas: React.FC = () => { isImportModalOpen={isImportModalOpen} setIsImportFlowModalOpen={setIsImportFlowModalOpen} /> + + {contextHolder} ); }; From 0392ab6682e7393cd5a292f96ef9cee2e65c98bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Fri, 30 Aug 2024 18:11:42 +0800 
Subject: [PATCH 16/60] feat:update deploy state in SaveFlowModal component --- web/components/flow/canvas-modal/save-flow-modal.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/components/flow/canvas-modal/save-flow-modal.tsx b/web/components/flow/canvas-modal/save-flow-modal.tsx index 43d43ac61..64faac882 100644 --- a/web/components/flow/canvas-modal/save-flow-modal.tsx +++ b/web/components/flow/canvas-modal/save-flow-modal.tsx @@ -27,7 +27,7 @@ export const SaveFlowModal: React.FC = ({ const [form] = Form.useForm(); const [messageApi, contextHolder] = message.useMessage(); - const [deploy, setDeploy] = useState(true); + const [deploy, setDeploy] = useState(false); const [id, setId] = useState(router.query.id || ''); useEffect(() => { From 6030fe1579951ad29c3a8381e16b80e62477f48e Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 30 Aug 2024 18:22:07 +0800 Subject: [PATCH 17/60] chore: Fix import flow error --- dbgpt/serve/flow/api/endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 5d3ebd75d..68e3a815e 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -517,7 +517,7 @@ async def import_flow( raise HTTPException( status_code=400, detail="invalid json file, missing 'flow' key" ) - flow = _parse_flow_template_from_json(json_dict["flow"]) + flow = _parse_flow_template_from_json(json_dict) elif file_extension == "zip": from ..service.share_utils import _parse_flow_from_zip_file From 293fbb0f9ba05f614d9dec78c6054bcc57927d66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Fri, 30 Aug 2024 20:06:43 +0800 Subject: [PATCH 18/60] chore: Remove unnecessary code in Flow component --- web/pages/construct/flow/index.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/pages/construct/flow/index.tsx b/web/pages/construct/flow/index.tsx index 4d8fcb42a..f82306bd4 100644 --- a/web/pages/construct/flow/index.tsx +++ b/web/pages/construct/flow/index.tsx @@ -126,7 +126,6 @@ function Flow() { copyFlowTemp.current = flow; form.setFieldValue('label', `${flow.label} Copy`); form.setFieldValue('name', `${flow.name}_copy`); - setDeploy(true); setEditable(true); setShowModal(true); }; @@ -256,8 +255,10 @@ function Flow() { + { setShowModal(false); From f8789c0e38e55101b83bcfff30618ce568b13a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Mon, 2 Sep 2024 20:04:01 +0800 Subject: [PATCH 19/60] feat: Add AddFlowVariable component for managing flow variables --- web/components/flow/add-flow-variable.tsx | 151 ++++++++++++++++++++++ web/locales/en/flow.ts | 1 + web/locales/zh/flow.ts | 1 + web/new-components/layout/Construct.tsx | 19 +++ web/pages/construct/flow/canvas/index.tsx | 3 + 5 files changed, 175 insertions(+) create mode 100644 web/components/flow/add-flow-variable.tsx diff --git a/web/components/flow/add-flow-variable.tsx b/web/components/flow/add-flow-variable.tsx new file mode 100644 index 000000000..808fd50f4 --- /dev/null +++ b/web/components/flow/add-flow-variable.tsx @@ -0,0 +1,151 @@ +import { apiInterceptors, getFlowNodes } from '@/client/api'; +import { IFlowNode } from '@/types/flow'; +import { FLOW_NODES_KEY } from '@/utils'; +import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; +import { Button, Form, Input, Modal } from 'antd'; +import React, { useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +type GroupType = 
{ category: string; categoryLabel: string; nodes: IFlowNode[] }; + +const AddFlowVariable: React.FC = () => { + const { t } = useTranslation(); + const [operators, setOperators] = useState>([]); + const [resources, setResources] = useState>([]); + const [operatorsGroup, setOperatorsGroup] = useState([]); + const [resourcesGroup, setResourcesGroup] = useState([]); + const [isModalOpen, setIsModalOpen] = useState(false); + + const showModal = () => { + setIsModalOpen(true); + }; + + useEffect(() => { + getNodes(); + }, []); + + async function getNodes() { + const [_, data] = await apiInterceptors(getFlowNodes()); + if (data && data.length > 0) { + localStorage.setItem(FLOW_NODES_KEY, JSON.stringify(data)); + const operatorNodes = data.filter(node => node.flow_type === 'operator'); + const resourceNodes = data.filter(node => node.flow_type === 'resource'); + setOperators(operatorNodes); + setResources(resourceNodes); + setOperatorsGroup(groupNodes(operatorNodes)); + setResourcesGroup(groupNodes(resourceNodes)); + } + } + + function groupNodes(data: IFlowNode[]) { + const groups: GroupType[] = []; + const categoryMap: Record = {}; + data.forEach(item => { + const { category, category_label } = item; + if (!categoryMap[category]) { + categoryMap[category] = { category, categoryLabel: category_label, nodes: [] }; + groups.push(categoryMap[category]); + } + categoryMap[category].nodes.push(item); + }); + return groups; + } + + const formItemLayout = { + labelCol: { + xs: { span: 24 }, + sm: { span: 4 }, + }, + wrapperCol: { + xs: { span: 24 }, + sm: { span: 20 }, + }, + }; + + const formItemLayoutWithOutLabel = { + wrapperCol: { + xs: { span: 24, offset: 0 }, + sm: { span: 20, offset: 2 }, + }, + }; + + const onFinish = (values: any) => { + console.log('Received values of form:', values); + }; + + return ( + <> + + + + + + )} + + + + + + + + + ); +}; + +export default AddFlowVariable; diff --git a/web/locales/en/flow.ts b/web/locales/en/flow.ts index fffe4a6a3..396e6d5ee 100644 --- a/web/locales/en/flow.ts +++ b/web/locales/en/flow.ts @@ -17,4 +17,5 @@ export const FlowEn = { Yes: 'Yes', No: 'No', Please_Add_Nodes_First: 'Please add nodes first', + Add_Global_Variable_of_Flow: 'Add global variable of flow', }; diff --git a/web/locales/zh/flow.ts b/web/locales/zh/flow.ts index 9ce76d733..338c6df50 100644 --- a/web/locales/zh/flow.ts +++ b/web/locales/zh/flow.ts @@ -17,4 +17,5 @@ export const FlowZn = { Yes: '是', No: '否', Please_Add_Nodes_First: '请先添加节点', + Add_Global_Variable_of_Flow: '添加 Flow 全局变量', }; diff --git a/web/new-components/layout/Construct.tsx b/web/new-components/layout/Construct.tsx index 8d76e62fd..f68ef3ccf 100644 --- a/web/new-components/layout/Construct.tsx +++ b/web/new-components/layout/Construct.tsx @@ -12,6 +12,7 @@ import { t } from 'i18next'; import { useRouter } from 'next/router'; import React from 'react'; import './style.css'; + function ConstructLayout({ children }: { children: React.ReactNode }) { const items = [ { @@ -19,6 +20,15 @@ function ConstructLayout({ children }: { children: React.ReactNode }) { name: t('App'), path: '/app', icon: , + // operations: ( + // + // ), }, { key: 'flow', @@ -102,6 +112,15 @@ function ConstructLayout({ children }: { children: React.ReactNode }) { onTabClick={key => { router.push(`/construct/${key}`); }} + // tabBarExtraContent={ + // + // } /> diff --git a/web/pages/construct/flow/canvas/index.tsx b/web/pages/construct/flow/canvas/index.tsx index 5e5676c66..5957074ea 100644 --- a/web/pages/construct/flow/canvas/index.tsx +++ 
b/web/pages/construct/flow/canvas/index.tsx @@ -17,6 +17,7 @@ import ReactFlow, { useReactFlow, } from 'reactflow'; // import AddNodes from '@/components/flow/add-nodes'; +import AddFlowVariable from '@/components/flow/add-flow-variable'; import AddNodesSider from '@/components/flow/add-nodes-sider'; import ButtonEdge from '@/components/flow/button-edge'; import { ExportFlowModal, ImportFlowModal, SaveFlowModal } from '@/components/flow/canvas-modal'; @@ -249,7 +250,9 @@ const Canvas: React.FC = () => { > + {/* */} + From 4edc64a4c936f23059c4941008ff416bb57c7178 Mon Sep 17 00:00:00 2001 From: yanzhiyong <932374019@qq.com> Date: Tue, 3 Sep 2024 00:58:53 +0800 Subject: [PATCH 20/60] feat: import update canvas flow --- web/components/flow/canvas-modal/import-flow-modal.tsx | 7 +++++-- web/components/flow/canvas-node.tsx | 4 ---- web/components/flow/node-renderer/upload.tsx | 1 + 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/web/components/flow/canvas-modal/import-flow-modal.tsx b/web/components/flow/canvas-modal/import-flow-modal.tsx index fbf7e87df..8832c7945 100644 --- a/web/components/flow/canvas-modal/import-flow-modal.tsx +++ b/web/components/flow/canvas-modal/import-flow-modal.tsx @@ -4,7 +4,7 @@ import { Button, Form, GetProp, Modal, Radio, Space, Upload, UploadFile, UploadP import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { Edge, Node } from 'reactflow'; - +import CanvasWrapper from '@/pages/construct/flow/canvas/index'; type Props = { isImportModalOpen: boolean; setNodes: React.Dispatch[]>>; @@ -37,9 +37,12 @@ export const ImportFlowModal: React.FC = ({ isImportModalOpen, setIsImpor const [, , res] = await apiInterceptors(importFlow(formData)); if (res?.success) { - messageApi.success(t('Export_Flow_Success')); + messageApi.success(t('Import_Flow_Success')); + localStorage.setItem('importFlowData', JSON.stringify(res?.data)); + CanvasWrapper(res?.data) } else if (res?.err_msg) { messageApi.error(res?.err_msg); + } setIsImportFlowModalOpen(false); }; diff --git a/web/components/flow/canvas-node.tsx b/web/components/flow/canvas-node.tsx index 37449337d..568c51ebf 100644 --- a/web/components/flow/canvas-node.tsx +++ b/web/components/flow/canvas-node.tsx @@ -19,7 +19,6 @@ type CanvasNodeProps = { function TypeLabel({ label }: { label: string }) { return
{label}
; } -const forceTypeList = ['file', 'multiple_files', 'time']; const CanvasNode: React.FC = ({ data }) => { const node = data; @@ -128,9 +127,6 @@ const CanvasNode: React.FC = ({ data }) => { function onParameterValuesChange(changedValues: any, allValues: any) { const [changedKey, changedVal] = Object.entries(changedValues)[0]; - if (!allValues?.force && forceTypeList.includes(changedKey)) { - return; - } updateCurrentNodeValue(changedKey, changedVal); if (changedVal) { updateDependsNodeValue(changedKey, changedVal); diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index 93232abdc..83cfa718e 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -16,6 +16,7 @@ export const renderUpload = (params: Props) => { const { t } = useTranslation(); const urlList = useRef([]); const { data, formValuesChange } = params; +console.log(data); const attr = convertKeysToCamelCase(data.ui?.attr || {}); const [uploading, setUploading] = useState(false); From 1bc77f91488e04a37b7f9e9cb4799f573ca15dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 3 Sep 2024 02:42:23 +0800 Subject: [PATCH 21/60] feat: Update AddFlowVariable component to include parameter management --- web/components/flow/add-flow-variable.tsx | 280 ++++++++++++------ .../flow/canvas-modal/export-flow-modal.tsx | 1 - .../flow/canvas-modal/import-flow-modal.tsx | 1 - .../flow/canvas-modal/save-flow-modal.tsx | 3 +- web/locales/en/flow.ts | 1 + web/locales/zh/flow.ts | 1 + 6 files changed, 185 insertions(+), 102 deletions(-) diff --git a/web/components/flow/add-flow-variable.tsx b/web/components/flow/add-flow-variable.tsx index 808fd50f4..cc83e276c 100644 --- a/web/components/flow/add-flow-variable.tsx +++ b/web/components/flow/add-flow-variable.tsx @@ -1,78 +1,115 @@ -import { apiInterceptors, getFlowNodes } from '@/client/api'; -import { IFlowNode } from '@/types/flow'; -import { FLOW_NODES_KEY } from '@/utils'; +// import { IFlowNode } from '@/types/flow'; import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; -import { Button, Form, Input, Modal } from 'antd'; -import React, { useEffect, useState } from 'react'; +import { Button, Form, Input, Modal, Select, Space } from 'antd'; +import React, { useState } from 'react'; import { useTranslation } from 'react-i18next'; -type GroupType = { category: string; categoryLabel: string; nodes: IFlowNode[] }; +// ype GroupType = { category: string; categoryLabel: string; nodes: IFlowNode[] }; +type ValueType = 'str' | 'int' | 'float' | 'bool' | 'ref'; + +const { Option } = Select; + +const DAG_PARAM_KEY = 'dbgpt.core.flow.params'; +const DAG_PARAM_SCOPE = 'flow_priv'; const AddFlowVariable: React.FC = () => { const { t } = useTranslation(); - const [operators, setOperators] = useState>([]); - const [resources, setResources] = useState>([]); - const [operatorsGroup, setOperatorsGroup] = useState([]); - const [resourcesGroup, setResourcesGroup] = useState([]); + // const [operators, setOperators] = useState>([]); + // const [resources, setResources] = useState>([]); + // const [operatorsGroup, setOperatorsGroup] = useState([]); + // const [resourcesGroup, setResourcesGroup] = useState([]); const [isModalOpen, setIsModalOpen] = useState(false); + const [form] = Form.useForm(); // const [form] = Form.useForm(); const showModal = () => { setIsModalOpen(true); }; - useEffect(() => { - getNodes(); - }, []); - - async function getNodes() { - const [_, data] = 
diff --git a/web/components/flow/add-flow-variable.tsx b/web/components/flow/add-flow-variable.tsx
index 808fd50f4..cc83e276c 100644
--- a/web/components/flow/add-flow-variable.tsx
+++ b/web/components/flow/add-flow-variable.tsx
@@ -1,78 +1,115 @@
-import { apiInterceptors, getFlowNodes } from '@/client/api';
-import { IFlowNode } from '@/types/flow';
-import { FLOW_NODES_KEY } from '@/utils';
+// import { IFlowNode } from '@/types/flow';
 import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons';
-import { Button, Form, Input, Modal } from 'antd';
-import React, { useEffect, useState } from 'react';
+import { Button, Form, Input, Modal, Select, Space } from 'antd';
+import React, { useState } from 'react';
 import { useTranslation } from 'react-i18next';
 
-type GroupType = { category: string; categoryLabel: string; nodes: IFlowNode[] };
+// type GroupType = { category: string; categoryLabel: string; nodes: IFlowNode[] };
+type ValueType = 'str' | 'int' | 'float' | 'bool' | 'ref';
+
+const { Option } = Select;
+
+const DAG_PARAM_KEY = 'dbgpt.core.flow.params';
+const DAG_PARAM_SCOPE = 'flow_priv';
 
 const AddFlowVariable: React.FC = () => {
   const { t } = useTranslation();
-  const [operators, setOperators] = useState<Array<IFlowNode>>([]);
-  const [resources, setResources] = useState<Array<IFlowNode>>([]);
-  const [operatorsGroup, setOperatorsGroup] = useState<GroupType[]>([]);
-  const [resourcesGroup, setResourcesGroup] = useState<GroupType[]>([]);
+  // const [operators, setOperators] = useState<Array<IFlowNode>>([]);
+  // const [resources, setResources] = useState<Array<IFlowNode>>([]);
+  // const [operatorsGroup, setOperatorsGroup] = useState<GroupType[]>([]);
+  // const [resourcesGroup, setResourcesGroup] = useState<GroupType[]>([]);
   const [isModalOpen, setIsModalOpen] = useState(false);
+  const [form] = Form.useForm();
   // const [form] = Form.useForm();
 
   const showModal = () => {
     setIsModalOpen(true);
   };
 
-  useEffect(() => {
-    getNodes();
-  }, []);
-
-  async function getNodes() {
-    const [_, data] = await apiInterceptors(getFlowNodes());
-    if (data && data.length > 0) {
-      localStorage.setItem(FLOW_NODES_KEY, JSON.stringify(data));
-      const operatorNodes = data.filter(node => node.flow_type === 'operator');
-      const resourceNodes = data.filter(node => node.flow_type === 'resource');
-      setOperators(operatorNodes);
-      setResources(resourceNodes);
-      setOperatorsGroup(groupNodes(operatorNodes));
-      setResourcesGroup(groupNodes(resourceNodes));
-    }
-  }
+  // TODO: get keys
+  // useEffect(() => {
+  //   getNodes();
+  // }, []);
 
-  function groupNodes(data: IFlowNode[]) {
-    const groups: GroupType[] = [];
-    const categoryMap: Record<string, GroupType> = {};
-    data.forEach(item => {
-      const { category, category_label } = item;
-      if (!categoryMap[category]) {
-        categoryMap[category] = { category, categoryLabel: category_label, nodes: [] };
-        groups.push(categoryMap[category]);
-      }
-      categoryMap[category].nodes.push(item);
-    });
-    return groups;
-  }
+  // async function getNodes() {
+  //   const [_, data] = await apiInterceptors(getFlowNodes());
+  //   if (data && data.length > 0) {
+  //     localStorage.setItem(FLOW_NODES_KEY, JSON.stringify(data));
+  //     const operatorNodes = data.filter(node => node.flow_type === 'operator');
+  //     const resourceNodes = data.filter(node => node.flow_type === 'resource');
+  //     setOperators(operatorNodes);
+  //     setResources(resourceNodes);
+  //     setOperatorsGroup(groupNodes(operatorNodes));
+  //     setResourcesGroup(groupNodes(resourceNodes));
+  //   }
+  // }
 
-  const formItemLayout = {
-    labelCol: {
-      xs: { span: 24 },
-      sm: { span: 4 },
-    },
-    wrapperCol: {
-      xs: { span: 24 },
-      sm: { span: 20 },
-    },
-  };
-
-  const formItemLayoutWithOutLabel = {
-    wrapperCol: {
-      xs: { span: 24, offset: 0 },
-      sm: { span: 20, offset: 2 },
-    },
-  };
+  // function groupNodes(data: IFlowNode[]) {
+  //   const groups: GroupType[] = [];
+  //   const categoryMap: Record<string, GroupType> = {};
+  //   data.forEach(item => {
+  //     const { category, category_label } = item;
+  //     if (!categoryMap[category]) {
+  //       categoryMap[category] = { category, categoryLabel: category_label, nodes: [] };
+  //       groups.push(categoryMap[category]);
+  //     }
+  //     categoryMap[category].nodes.push(item);
+  //   });
+  //   return groups;
+  // }
 
   const onFinish = (values: any) => {
     console.log('Received values of form:', values);
   };
 
+  function onNameChange(e: React.ChangeEvent<HTMLInputElement>, index: number) {
+    const name = e.target.value;
+
+    const result = name
+      ?.split('_')
+      ?.map(word => word.charAt(0).toUpperCase() + word.slice(1))
+      ?.join(' ');
+
+    form.setFields([
+      {
+        name: ['parameters', index, 'label'],
+        value: result,
+      },
+    ]);
+
+    // change value to ref
+    const type = form.getFieldValue(['parameters', index, 'value_type']);
+
+    if (type === 'ref') {
+      const parameters = form.getFieldValue('parameters');
+      const param = parameters?.[index];
+
+      if (param) {
+        const { name = '' } = param;
+        param.value = `${DAG_PARAM_KEY}:${name}@scope:${DAG_PARAM_SCOPE}`;
+
+        form.setFieldsValue({
+          parameters: [...parameters],
+        });
+      }
+    }
+  }
+
+  function onValueTypeChange(type: ValueType, index: number) {
+    if (type === 'ref') {
+      const parameters = form.getFieldValue('parameters');
+      const param = parameters?.[index];
+
+      if (param) {
+        const { name = '' } = param;
+        param.value = `${DAG_PARAM_KEY}:${name}@scope:${DAG_PARAM_SCOPE}`;
+
+        form.setFieldsValue({
+          parameters: [...parameters],
+        });
+      }
+    }
+  }
+
   return (
     <>
-
-      )}
-
-
-
+
+
+
+
+
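One detail worth calling out in the hunk above: besides rewriting ref values, onNameChange derives the display label from the snake_case variable name. The same transform, pulled out as a standalone sketch (toLabel is an illustrative name; the patch performs this inline on the form field):

    // 'user_input_key' -> 'User Input Key'
    function toLabel(name: string): string {
      return name
        .split('_')
        .map(word => word.charAt(0).toUpperCase() + word.slice(1))
        .join(' ');
    }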
diff --git a/web/components/flow/canvas-modal/export-flow-modal.tsx b/web/components/flow/canvas-modal/export-flow-modal.tsx
index 31850dc56..0d056abac 100644
--- a/web/components/flow/canvas-modal/export-flow-modal.tsx
+++ b/web/components/flow/canvas-modal/export-flow-modal.tsx
@@ -43,7 +43,6 @@ export const ExportFlowModal: React.FC = ({
   return (
     <>
-
       setIsExportFlowModalOpen(false)}
diff --git a/web/components/flow/canvas-modal/import-flow-modal.tsx b/web/components/flow/canvas-modal/import-flow-modal.tsx
index fbf7e87df..803f37d92 100644
--- a/web/components/flow/canvas-modal/import-flow-modal.tsx
+++ b/web/components/flow/canvas-modal/import-flow-modal.tsx
@@ -61,7 +61,6 @@ export const ImportFlowModal: React.FC = ({ isImportModalOpen, setIsImpor
   return (
     <>
-
       setIsImportFlowModalOpen(false)}
diff --git a/web/components/flow/canvas-modal/save-flow-modal.tsx b/web/components/flow/canvas-modal/save-flow-modal.tsx
index 64faac882..708f6f768 100644
--- a/web/components/flow/canvas-modal/save-flow-modal.tsx
+++ b/web/components/flow/canvas-modal/save-flow-modal.tsx
@@ -89,7 +89,6 @@ export const SaveFlowModal: React.FC = ({
   return (
     <>
-
       {
@@ -142,7 +141,7 @@ export const SaveFlowModal: React.FC = ({