Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add chart for chat_data scene #903

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion pilot/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,54 @@ def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
else:
return prompt_define_response

def _generate_numbered_list(self) -> str:
"""this function is moved from excel_analyze/chat.py,and used by subclass.
Returns:

"""
antv_charts = [
{"response_line_chart": "used to display comparative trend analysis data"},
{
"response_pie_chart": "suitable for scenarios such as proportion and distribution statistics"
},
{
"response_table": "suitable for display with many display columns or non-numeric columns"
},
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
{
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
},
{
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
},
{
"response_donut_chart": "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."
},
{
"response_area_chart": "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."
},
{
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
},
]

# command_strings = []
# if CFG.command_disply:
# for name, item in CFG.command_disply.commands.items():
# if item.enabled:
# command_strings.append(f"{name}:{item.description}")
# command_strings += [
# str(item)
# for item in CFG.command_disply.commands.values()
# if item.enabled
# ]
return "\n".join(
f"{key}:{value}"
for dict_item in antv_charts
for key, value in dict_item.items()
)



def _build_model_operator(
is_stream: bool = False, dag_name: str = "llm_model_dag"
Expand Down Expand Up @@ -665,4 +713,4 @@ def _load_history_messages(
ModelMessage(role=message_type, content=message_content)
)

return history_text if str_message else history_messages
return history_text if str_message else history_messages
43 changes: 0 additions & 43 deletions pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,49 +50,6 @@ def __init__(self, chat_param: Dict):
self.api_call = ApiCall(display_registry=CFG.command_disply)
super().__init__(chat_param=chat_param)

def _generate_numbered_list(self) -> str:
antv_charts = [
{"response_line_chart": "used to display comparative trend analysis data"},
{
"response_pie_chart": "suitable for scenarios such as proportion and distribution statistics"
},
{
"response_table": "suitable for display with many display columns or non-numeric columns"
},
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
{
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
},
{
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
},
{
"response_donut_chart": "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."
},
{
"response_area_chart": "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."
},
{
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
},
]

# command_strings = []
# if CFG.command_disply:
# for name, item in CFG.command_disply.commands.items():
# if item.enabled:
# command_strings.append(f"{name}:{item.description}")
# command_strings += [
# str(item)
# for item in CFG.command_disply.commands.values()
# if item.enabled
# ]
return "\n".join(
f"{key}:{value}"
for dict_item in antv_charts
for key, value in dict_item.items()
)

@trace()
async def generate_input_values(self) -> Dict:
input_values = {
Expand Down
2 changes: 2 additions & 0 deletions pilot/scene/chat_db/auto_execute/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, chat_param: Dict):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)

self.top_k: int = 50
self.api_call = ApiCall(display_registry=CFG.command_disply)

@trace()
async def generate_input_values(self) -> Dict:
Expand Down Expand Up @@ -76,6 +77,7 @@ async def generate_input_values(self) -> Dict:
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": table_infos,
"display_type": self._generate_numbered_list(),
}
return input_values

Expand Down
17 changes: 10 additions & 7 deletions pilot/scene/chat_db/auto_execute/out_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
class SqlAction(NamedTuple):
sql: str
thoughts: Dict
display: str

def to_dict(self) -> Dict[str, Dict]:
return {"sql": self.sql, "thoughts": self.thoughts}
return {"sql": self.sql, "thoughts": self.thoughts, "display": self.display, }


logger = logging.getLogger(__name__)
Expand All @@ -40,19 +41,21 @@ def parse_prompt_response(self, model_out_text):
logger.info(f"clean prompt response: {clean_str}")
# Compatible with community pure sql output model
if self.is_sql_statement(clean_str):
return SqlAction(clean_str, "")
return SqlAction(clean_str, "", "")
else:
try:
response = json.loads(clean_str)
response = json.loads(clean_str, strict=False)
for key in sorted(response):
if key.strip() == "sql":
sql = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
return SqlAction(sql, thoughts)
if key.strip() == "display_type":
display = response[key]
return SqlAction(sql, thoughts, display)
except Exception as e:
logger.error("json load faild")
return SqlAction("", clean_str)
logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "")

def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
Expand All @@ -63,7 +66,7 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
return f"""{speak}"""

df = data(prompt_response.sql)
param["type"] = "response_table"
param["type"] = prompt_response.display
param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
Expand Down
3 changes: 3 additions & 0 deletions pilot/scene/chat_db/auto_execute/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
3.You can only use the tables provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql queries." It is prohibited to fabricate information at will.
4.Please be careful not to mistake the relationship between tables and columns when generating SQL.
5.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
6.Please choose the best one from the display methods given below for data rendering, and put the type name into the name parameter value that returns the required format. If you cannot find the most suitable one, use 'Table' as the display method. , the available data display methods are as follows: {display_type}

User Question:
{user_input}
Expand All @@ -47,6 +48,7 @@
3. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
4. 请注意生成SQL时不要弄错表和列的关系
5. 请检查SQL的正确性,并保证正确的情况下优化查询性能
6.请从如下给出的展示方式种选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值种,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {display_type}
用户问题:
{user_input}
请一步步思考并按照以下JSON格式回复:
Expand All @@ -66,6 +68,7 @@
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
"display_type": "Data display method",
}

PROMPT_SEP = SeparatorStyle.SINGLE.value
Expand Down