Skip to content

Commit

Permalink
feat(ChatKnowledge): Add document summary and Integrate ant-v (#809)
Browse files Browse the repository at this point in the history
ChatKnowledge

- Support document summary
- Fix Milvus reference bug
- Fix Chunk list page bug

ChatDB

- Merge the scenes ChatDB and ChatData into one scene.
- Integrate ant-v 

ChatExcel
- Adjusted the prompt assembly to support ChatExcel functionality.

Close #792
  • Loading branch information
csunny authored Nov 20, 2023
2 parents ecc5d5d + aad063b commit 41430ef
Show file tree
Hide file tree
Showing 143 changed files with 2,383 additions and 1,180 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
copyright = "2023, csunny"
author = "csunny"

version = "👏👏 0.4.1"
version = "👏👏 0.4.2"
html_title = project + " " + version

# -- General configuration ---------------------------------------------------
Expand Down
149 changes: 117 additions & 32 deletions pilot/base_modules/agent/commands/command_mange.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import json
import logging
import xml.etree.ElementTree as ET
import pandas as pd

from pilot.common.json_utils import serialize
from datetime import datetime
from typing import Any, Callable, Optional, List
from pydantic import BaseModel
Expand Down Expand Up @@ -184,14 +186,21 @@ class PluginStatus(BaseModel):
start_time = datetime.now().timestamp() * 1000
end_time: int = None

df: Any = None


class ApiCall:
agent_prefix = "<api-call>"
agent_end = "</api-call>"
name_prefix = "<name>"
name_end = "</name>"

def __init__(self, plugin_generator: Any = None, display_registry: Any = None):
def __init__(
self,
plugin_generator: Any = None,
display_registry: Any = None,
backend_rendering: bool = False,
):
# self.name: str = ""
# self.status: Status = Status.TODO.value
# self.logo_url: str = None
Expand All @@ -204,6 +213,7 @@ def __init__(self, plugin_generator: Any = None, display_registry: Any = None):
self.plugin_generator = plugin_generator
self.display_registry = display_registry
self.start_time = datetime.now().timestamp() * 1000
self.backend_rendering: bool = False

def __repr__(self):
return f"ApiCall(name={self.name}, status={self.status}, args={self.args})"
Expand All @@ -227,7 +237,7 @@ def __is_need_wait_plugin_call(self, api_call_context):
i += 1
return False

def __check_last_plugin_call_ready(self, all_context):
def check_last_plugin_call_ready(self, all_context):
start_agent_count = all_context.count(self.agent_prefix)
end_agent_count = all_context.count(self.agent_end)

Expand All @@ -236,7 +246,14 @@ def __check_last_plugin_call_ready(self, all_context):
return False

def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
error_md_tags = ["```", "```python", "```xml", "```json", "```markdown"]
error_md_tags = [
"```",
"```python",
"```xml",
"```json",
"```markdown",
"```sql",
]
if include_end == False:
md_tag_end = ""
else:
Expand All @@ -255,40 +272,25 @@ def __deal_error_md_tags(self, all_context, api_context, include_end: bool = Tru
return all_context

def api_view_context(self, all_context: str, display_mode: bool = False):
error_mk_tags = ["```", "```python", "```xml"]
call_context_map = extract_content_open_ending(
all_context, self.agent_prefix, self.agent_end, True
)
for api_index, api_context in call_context_map.items():
api_status = self.plugin_status_map.get(api_context)
if api_status is not None:
if display_mode:
if api_status.api_result:
all_context = self.__deal_error_md_tags(
all_context, api_context
)
all_context = self.__deal_error_md_tags(all_context, api_context)
if Status.FAILED.value == api_status.status:
all_context = all_context.replace(
api_context, api_status.api_result
api_context,
f'\n<span style="color:red">Error:</span>{api_status.err_msg}\n'
+ self.to_view_antv_vis(api_status),
)
else:
if api_status.status == Status.FAILED.value:
all_context = self.__deal_error_md_tags(
all_context, api_context
)
all_context = all_context.replace(
api_context,
f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """,
)
else:
cost = (api_status.end_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
all_context = self.__deal_error_md_tags(
all_context, api_context
)
all_context = all_context.replace(
api_context,
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
)
all_context = all_context.replace(
api_context, self.to_view_antv_vis(api_status)
)

else:
all_context = self.__deal_error_md_tags(
all_context, api_context, False
Expand All @@ -302,8 +304,8 @@ def api_view_context(self, all_context: str, display_mode: bool = False):
now_time = datetime.now().timestamp() * 1000
cost = (now_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
for tag in error_mk_tags:
all_context = all_context.replace(tag + api_context, api_context)
all_context = self.__deal_error_md_tags(all_context, api_context)

all_context = all_context.replace(
api_context,
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
Expand Down Expand Up @@ -348,18 +350,54 @@ def __to_view_param_str(self, api_status):

if api_status.api_result:
param["result"] = api_status.api_result
return json.dumps(param)

return json.dumps(param, default=serialize, ensure_ascii=False)

def to_view_text(self, api_status: PluginStatus):
    """Wrap the serialized view parameters of ``api_status`` in a ``<dbgpt-view>`` tag.

    Args:
        api_status: Plugin call status whose parameters are serialized by
            ``__to_view_param_str``.

    Returns:
        The ``<dbgpt-view>`` element as a UTF-8 decoded string; the payload is
        the element text, so ``&``, ``<`` and ``>`` are XML-escaped.
    """
    view_node = ET.Element("dbgpt-view")
    view_node.text = self.__to_view_param_str(api_status)
    return ET.tostring(view_node, encoding="utf-8").decode("utf-8")

def to_view_antv_vis(self, api_status: PluginStatus):
    """Render one api-call status as view markup.

    Args:
        api_status: Plugin call status. ``args`` is read for the ``sql`` entry
            and ``df`` for the query result when backend rendering is on.

    Returns:
        When ``self.backend_rendering`` is truthy, an HTML fragment holding
        the SQL text and the result rendered as an HTML table. Otherwise a
        ``<chart-view>`` element whose ``content`` attribute carries the JSON
        built by ``__to_antv_vis_param``.
    """
    if self.backend_rendering:
        # Failed calls are rendered too, and they may carry no DataFrame or
        # no "sql" argument; render an empty table / SQL instead of raising.
        sql_text = api_status.args.get("sql", "") if api_status.args else ""
        if api_status.df is None:
            table_str = ""
        else:
            html_table = api_status.df.to_html(
                index=False, escape=False, sparsify=False
            )
            # Collapse every whitespace run (newlines included) into a single
            # space. Joining with "" would also remove the spaces inside tags
            # (`<table border=...>` -> `<tableborder=...>`) and in cell text.
            table_str = " ".join(html_table.split())
        html = f""" \n<div><b>[SQL]{sql_text}</b></div><div class="w-full overflow-auto">{table_str}</div>\n """
        return html
    else:
        api_call_element = ET.Element("chart-view")
        # The JSON goes into an attribute, so ElementTree escapes the quotes.
        api_call_element.attrib["content"] = self.__to_antv_vis_param(api_status)
        # Non-empty text forces an explicit closing tag instead of `<chart-view />`.
        api_call_element.text = "\n"
        result = ET.tostring(api_call_element, encoding="utf-8")
        return result.decode("utf-8")

# return f'<chart-view content="{self.__to_antv_vis_param(api_status)}">'

def __to_antv_vis_param(self, api_status: PluginStatus):
    """Build the JSON payload placed in the ``content`` attribute of a chart view.

    Args:
        api_status: Plugin call status providing ``name``, ``args`` and
            ``api_result``.

    Returns:
        A JSON string (non-ASCII characters kept as-is) with ``type`` when the
        name is truthy, ``sql`` when args are truthy, and ``data`` always
        (an empty list when there is no result).
    """
    payload = {}
    chart_type = api_status.name
    if chart_type:
        payload["type"] = chart_type
    call_args = api_status.args
    if call_args:
        payload["sql"] = call_args["sql"]
    # "data" is always emitted; fall back to an empty list when there is no result.
    payload["data"] = api_status.api_result if api_status.api_result else []
    return json.dumps(payload, ensure_ascii=False)

def run(self, llm_text):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
Expand All @@ -379,7 +417,7 @@ def run(self, llm_text):
def run_display_sql(self, llm_text, sql_run_func):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
Expand All @@ -391,6 +429,7 @@ def run_display_sql(self, llm_text, sql_run_func):
param = {
"df": sql_run_func(sql),
}
value.df = param["df"]
if self.display_registry.is_valid_command(value.name):
value.api_result = self.display_registry.call(
value.name, **param
Expand All @@ -406,3 +445,49 @@ def run_display_sql(self, llm_text, sql_run_func):
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
return self.api_view_context(llm_text, True)

def display_sql_llmvis(self, llm_text, sql_run_func):
    """
    Render charts using the Antv standard protocol.

    Every api call parsed from ``llm_text`` that is still in the TODO state is
    executed: its ``sql`` argument is passed to ``sql_run_func`` and the
    result is stored on the status object, both as ``df`` and as a list of
    records in ``api_result``.

    Args:
        llm_text: LLM response text
        sql_run_func: sql run function; called with the SQL string, its result
            must provide ``to_json(orient=..., date_format=..., date_unit=...)``
    Returns:
        ChartView protocol text, produced by ``api_view_context`` in display mode
    Raises:
        ValueError: when parsing the api calls out of ``llm_text`` fails
    """
    try:
        # Only run once the last api-call block has been generated completely.
        if self.__is_need_wait_plugin_call(
            llm_text
        ) and self.check_last_plugin_call_ready(llm_text):
            self.update_from_context(llm_text)
            for value in self.plugin_status_map.values():
                if value.status != Status.TODO.value:
                    continue
                value.status = Status.RUNNING.value
                logging.info(f"sql展示执行:{value.name},{value.args}")
                try:
                    sql = value.args["sql"]
                    if sql is not None and len(sql) > 0:
                        data_df = sql_run_func(sql)
                        value.df = data_df
                        # Round-trip through JSON so the result holds only
                        # plain JSON types (dates become ISO strings).
                        value.api_result = json.loads(
                            data_df.to_json(
                                orient="records",
                                date_format="iso",
                                date_unit="s",
                            )
                        )
                        value.status = Status.COMPLETED.value
                    else:
                        value.status = Status.FAILED.value
                        value.err_msg = "No executable sql!"
                except Exception as e:
                    # A failing query marks this call as failed; the other
                    # calls are still processed.
                    value.status = Status.FAILED.value
                    value.err_msg = str(e)
                value.end_time = datetime.now().timestamp() * 1000
    except Exception as e:
        # Passing `e` as a bare positional argument without a placeholder made
        # logging fail to format the record; use a proper format string.
        logging.exception("Api parsing exception: %s", e)
        raise ValueError("Api parsing exception," + str(e)) from e

    return self.api_view_context(llm_text, True)
8 changes: 4 additions & 4 deletions pilot/base_modules/agent/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.faild(code="E0020", msg=f"Agent Hub Update Error! {e}")
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")


@router.post("/v1/agent/query", response_model=Result[str])
Expand Down Expand Up @@ -133,7 +133,7 @@ async def agent_install(plugin_name: str, user: str = None):
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")


@router.post("/v1/agent/uninstall", response_model=Result[str])
Expand All @@ -147,7 +147,7 @@ async def agent_uninstall(plugin_name: str, user: str = None):
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")


@router.post("/v1/personal/agent/upload", response_model=Result[str])
Expand All @@ -160,4 +160,4 @@ async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = No
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")
return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")
44 changes: 18 additions & 26 deletions pilot/common/chat_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,27 @@ async def llm_chat_response(chat_scene: str, **chat_param):
return chat.stream_call()


def run_async_tasks(
async def run_async_tasks(
tasks: List[Coroutine],
show_progress: bool = False,
progress_bar_desc: str = "Running async tasks",
concurrency_limit: int = None,
) -> List[Any]:
"""Run a list of async tasks."""

tasks_to_execute: List[Any] = tasks
if show_progress:
try:
import nest_asyncio
from tqdm.asyncio import tqdm

nest_asyncio.apply()
loop = asyncio.get_event_loop()

async def _tqdm_gather() -> List[Any]:
return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)

tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
return tqdm_outputs
# run the operation w/o tqdm on hitting a fatal
# may occur in some environments where tqdm.asyncio
# is not supported
except Exception:
pass

async def _gather() -> List[Any]:
return await asyncio.gather(*tasks_to_execute)

outputs: List[Any] = asyncio.run(_gather())
return outputs
if concurrency_limit:
semaphore = asyncio.Semaphore(concurrency_limit)

async def _execute_task(task):
async with semaphore:
return await task

# Execute tasks with semaphore limit
return await asyncio.gather(
*[_execute_task(task) for task in tasks_to_execute]
)
else:
return await asyncio.gather(*tasks_to_execute)

# outputs: List[Any] = asyncio.run(_gather())
return await _gather()
Loading

0 comments on commit 41430ef

Please sign in to comment.