diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 070c27b76..5837fee1b 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -458,6 +458,54 @@ def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any: else: return prompt_define_response + def _generate_numbered_list(self) -> str: + """this function is moved from excel_analyze/chat.py,and used by subclass. + Returns: + + """ + antv_charts = [ + {"response_line_chart": "used to display comparative trend analysis data"}, + { + "response_pie_chart": "suitable for scenarios such as proportion and distribution statistics" + }, + { + "response_table": "suitable for display with many display columns or non-numeric columns" + }, + # {"response_data_text":" the default display method, suitable for single-line or simple content display"}, + { + "response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc." + }, + { + "response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc." + }, + { + "response_donut_chart": "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc." + }, + { + "response_area_chart": "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc." + }, + { + "response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc." 
+ }, + ] + + # command_strings = [] + # if CFG.command_disply: + # for name, item in CFG.command_disply.commands.items(): + # if item.enabled: + # command_strings.append(f"{name}:{item.description}") + # command_strings += [ + # str(item) + # for item in CFG.command_disply.commands.values() + # if item.enabled + # ] + return "\n".join( + f"{key}:{value}" + for dict_item in antv_charts + for key, value in dict_item.items() + ) + + def _build_model_operator( is_stream: bool = False, dag_name: str = "llm_model_dag" @@ -665,4 +713,4 @@ def _load_history_messages( ModelMessage(role=message_type, content=message_content) ) - return history_text if str_message else history_messages + return history_text if str_message else history_messages \ No newline at end of file diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index bede9bc40..f281150f9 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -50,49 +50,6 @@ def __init__(self, chat_param: Dict): self.api_call = ApiCall(display_registry=CFG.command_disply) super().__init__(chat_param=chat_param) - def _generate_numbered_list(self) -> str: - antv_charts = [ - {"response_line_chart": "used to display comparative trend analysis data"}, - { - "response_pie_chart": "suitable for scenarios such as proportion and distribution statistics" - }, - { - "response_table": "suitable for display with many display columns or non-numeric columns" - }, - # {"response_data_text":" the default display method, suitable for single-line or simple content display"}, - { - "response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc." - }, - { - "response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc." 
class SqlAction(NamedTuple):
    """Parsed model reply for the auto-execute SQL scene.

    Fields are positional, in the order ``sql``, ``thoughts``, ``display``.
    ``display`` has a default so that two-argument construction
    ``SqlAction(sql, thoughts)`` remains valid.
    """

    # SQL statement to run; may be an empty string when no SQL was extracted.
    sql: str
    # Model's explanation. Annotated as Dict, but callers may also pass a
    # plain string (e.g. the raw reply text) - NOTE(review): confirm and
    # widen the annotation if that is intended.
    thoughts: Dict
    # Name of the display type used to render the query result.
    display: str = "response_table"

    def to_dict(self) -> Dict[str, object]:
        """Return the three fields as a plain dict.

        Returns:
            Dict[str, object]: Mapping with the keys ``"sql"``,
            ``"thoughts"`` and ``"display"`` holding the field values
            unchanged.
        """
        return {
            "sql": self.sql,
            "thoughts": self.thoughts,
            "display": self.display,
        }
# Skeleton of the JSON object the model is asked to reply with: each key is a
# field the reply must contain, each value a short hint describing that field.
RESPONSE_FORMAT_SIMPLE = dict(
    thoughts="thoughts summary to say to user",
    sql="SQL Query to run",
    display_type="Data display method",
)