From 59ac4ee5486627dbd39faf076b9ff801a8b6c4df Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Sat, 4 Nov 2023 18:08:28 +0800 Subject: [PATCH] feat(core): Support pass span id to threadpool --- assets/schema/knowledge_management.sql | 1 + pilot/scene/base_chat.py | 9 +++++- pilot/scene/chat_agent/chat.py | 7 ++++- pilot/scene/chat_dashboard/chat.py | 2 ++ .../chat_excel/excel_analyze/chat.py | 9 +++++- .../chat_excel/excel_learning/chat.py | 2 ++ pilot/scene/chat_db/auto_execute/chat.py | 29 ++++++++++++------- .../scene/chat_db/auto_execute/out_parser.py | 3 ++ pilot/scene/chat_db/professional_qa/chat.py | 2 ++ pilot/scene/chat_execution/chat.py | 2 ++ .../chat_knowledge/inner_db_summary/chat.py | 2 ++ pilot/scene/chat_knowledge/v1/chat.py | 2 ++ pilot/utils/executor_utils.py | 11 +++++-- pilot/utils/tracer/tracer_cli.py | 7 +++-- 14 files changed, 70 insertions(+), 18 deletions(-) diff --git a/assets/schema/knowledge_management.sql b/assets/schema/knowledge_management.sql index e38f731d6..a6f1bc478 100644 --- a/assets/schema/knowledge_management.sql +++ b/assets/schema/knowledge_management.sql @@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` ( `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', `result` TEXT NULL COMMENT 'knowledge content', `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `summary` LONGTEXT NULL COMMENT 'knowledge summary', `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', PRIMARY KEY (`id`), diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 294bd04ca..4529d3cb4 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -145,7 +145,14 @@ async def __call_base(self): ) self.current_message.tokens = 0 if self.prompt_template.template: - current_prompt = self.prompt_template.format(**input_values) + metadata = { + "template_scene": self.prompt_template.template_scene, + "input_values": input_values, + } + with root_tracer.start_span( + "BaseChat.__call_base.prompt_template.format", metadata=metadata + ): + current_prompt = self.prompt_template.format(**input_values) self.current_message.add_system_message(current_prompt) llm_messages = self.generate_llm_messages() diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py index d9a8f60c1..81af1b3b1 100644 --- a/pilot/scene/chat_agent/chat.py +++ b/pilot/scene/chat_agent/chat.py @@ -11,6 +11,7 @@ from .prompt import prompt from pilot.component import ComponentType from pilot.base_modules.agent.controller import ModuleAgent +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -51,6 +52,7 @@ def __init__(self, chat_param: Dict): self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) + @trace() async def generate_input_values(self) -> Dict[str, str]: input_values = { "user_goal": self.current_user_input, @@ -63,7 +65,10 @@ async def generate_input_values(self) -> Dict[str, str]: def stream_plugin_call(self, text): text = text.replace("\n", " ") - return self.api_call.run(text) + with root_tracer.start_span( + "ChatAgent.stream_plugin_call.api_call", metadata={"text": text} + ): + return self.api_call.run(text) def __list_to_prompt_str(self, list: List) -> str: return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 211aa7c04..6771fb3fc 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -13,6 +13,7 @@ from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -53,6 +54,7 @@ def __load_dashboard_template(self, template_name): data = f.read() return json.loads(data) + @trace() async def generate_input_values(self) -> Dict: try: from pilot.summary.db_summary_client import DBSummaryClient diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 064e7586c..fefc8142c 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -14,6 +14,7 @@ from pilot.common.path_utils import has_path from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.base_modules.agent.common.schema import Status +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -62,6 +63,7 @@ def _generate_numbered_list(self) -> str: # ] return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) + @trace() async def generate_input_values(self) -> Dict: input_values = { "user_input": self.current_user_input, @@ -88,4 +90,9 @@ async def prepare(self): def stream_plugin_call(self, text): text = text.replace("\n", " ") - return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex) + with root_tracer.start_span( + "ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text} + ): + return self.api_call.run_display_sql( + text, self.excel_reader.get_df_by_sql_ex + ) diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py index f05221eba..7d1730ad0 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py @@ -13,6 +13,7 @@ from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.json_utils.utilities import DateTimeEncoder from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -44,6 +45,7 @@ def __init__( if parent_mode: self.current_message.chat_mode = parent_mode.value() + @trace() async def generate_input_values(self) -> Dict: # colunms, datas = self.excel_reader.get_sample_data() colunms, datas = await blocking_func_to_async( diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index d9b901772..4d4bf3c0c 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -6,6 +6,7 @@ from pilot.configs.config import Config from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -35,10 +36,13 @@ def __init__(self, chat_param: Dict): raise ValueError( f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" ) - - self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) + with root_tracer.start_span( + "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name} + ): + self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.top_k: int = 200 + @trace() async def generate_input_values(self) -> Dict: """ generate input values @@ -55,13 +59,14 @@ async def generate_input_values(self) -> Dict: # query=self.current_user_input, # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, # ) - table_infos = await blocking_func_to_async( - self._executor, - client.get_db_summary, - self.db_name, - self.current_user_input, - CFG.KNOWLEDGE_SEARCH_TOP_SIZE, - ) + with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): + table_infos = await blocking_func_to_async( + self._executor, + client.get_db_summary, + self.db_name, + self.current_user_input, + CFG.KNOWLEDGE_SEARCH_TOP_SIZE, + ) except Exception as e: print("db summary find error!" + str(e)) if not table_infos: @@ -80,4 +85,8 @@ async def generate_input_values(self) -> Dict: def do_action(self, prompt_response): print(f"do_action:{prompt_response}") - return self.database.run(prompt_response.sql) + with root_tracer.start_span( + "ChatWithDbAutoExecute.do_action.run_sql", + metadata=prompt_response.to_dict(), + ): + return self.database.run(prompt_response.sql) diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index 577cac1ef..e583d945a 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -12,6 +12,9 @@ class SqlAction(NamedTuple): sql: str thoughts: Dict + def to_dict(self) -> Dict[str, Dict]: + return {"sql": self.sql, "thoughts": self.thoughts} + logger = logging.getLogger(__name__) diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 5ae76d37d..fde28d91b 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -6,6 +6,7 @@ from pilot.configs.config import Config from pilot.scene.chat_db.professional_qa.prompt import prompt from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -39,6 +40,7 @@ def __init__(self, chat_param: Dict): else len(self.tables) ) + @trace() async def generate_input_values(self) -> Dict: table_info = "" dialect = "mysql" diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index bdd78d7b7..2615918ff 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -6,6 +6,7 @@ from pilot.base_modules.agent.commands.command import execute_command from pilot.base_modules.agent import PluginPromptGenerator from .prompt import prompt +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -50,6 +51,7 @@ def __init__(self, chat_param: Dict): self.plugins_prompt_generator ) + @trace() async def generate_input_values(self) -> Dict: input_values = { "input": self.current_user_input, diff --git a/pilot/scene/chat_knowledge/inner_db_summary/chat.py b/pilot/scene/chat_knowledge/inner_db_summary/chat.py index 07a64aea9..f7c81bd77 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/chat.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/chat.py @@ -4,6 +4,7 @@ from pilot.configs.config import Config from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -31,6 +32,7 @@ def __init__( self.db_input = db_select self.db_summary = db_summary + @trace() async def generate_input_values(self) -> Dict: input_values = { "db_input": self.db_input, diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index d57b32b25..a9c63b268 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -15,6 +15,7 @@ from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.server.knowledge.service import KnowledgeService from pilot.utils.executor_utils import blocking_func_to_async +from pilot.utils.tracer import root_tracer, trace CFG = Config() @@ -92,6 +93,7 @@ def knowledge_reference_call(self, text): """return reference""" return text + f"\n\n{self.parse_source_view(self.sources)}" + @trace() async def generate_input_values(self) -> Dict: if self.space_context: self.prompt_template.template_define = self.space_context["prompt"]["scene"] diff --git a/pilot/utils/executor_utils.py b/pilot/utils/executor_utils.py index 2aac0d04d..26ee3c66e 100644 --- a/pilot/utils/executor_utils.py +++ b/pilot/utils/executor_utils.py @@ -1,5 +1,6 @@ from typing import Callable, Awaitable, Any import asyncio +import contextvars from abc import ABC, abstractmethod from concurrent.futures import Executor, ThreadPoolExecutor from functools import partial @@ -55,6 +56,12 @@ async def blocking_func_to_async( """ if asyncio.iscoroutinefunction(func): raise ValueError(f"The function {func} is not blocking function") + + # This function will be called within the new thread, capturing the current context + ctx = contextvars.copy_context() + + def run_with_context(): + return ctx.run(partial(func, *args, **kwargs)) + loop = asyncio.get_event_loop() - sync_function_noargs = partial(func, *args, **kwargs) - return await loop.run_in_executor(executor, sync_function_noargs) + return await loop.run_in_executor(executor, run_with_context) diff --git a/pilot/utils/tracer/tracer_cli.py b/pilot/utils/tracer/tracer_cli.py index 859fa4022..3fb9cba31 100644 --- a/pilot/utils/tracer/tracer_cli.py +++ b/pilot/utils/tracer/tracer_cli.py @@ -303,8 +303,6 @@ def chat( print(table.get_formatted_string(out_format=output, **out_kwargs)) if sys_table: print(sys_table.get_formatted_string(out_format=output, **out_kwargs)) - if hide_conv: - return if not found_trace_id: print(f"Can't found conversation with trace_id: {trace_id}") @@ -315,9 +313,12 @@ def chat( trace_spans = [s for s in reversed(trace_spans)] hierarchy = _build_trace_hierarchy(trace_spans) if tree: - print("\nInvoke Trace Tree:\n") + print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n") _print_trace_hierarchy(hierarchy) + if hide_conv: + return + trace_spans = _get_ordered_trace_from(hierarchy) table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details") split_long_text = output == "text"