From db823804dddbf1ba85fb4f473e1a5ff2ab08ab08 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 15 Apr 2024 11:41:28 +0800 Subject: [PATCH 1/3] feat: Support retry for 'Chat Data' --- dbgpt/_private/config.py | 9 ++ dbgpt/app/scene/base_chat.py | 100 +++++++++--------- .../scene/chat_db/auto_execute/out_parser.py | 23 ++-- dbgpt/app/scene/exceptions.py | 22 ++++ dbgpt/util/retry.py | 51 +++++++++ 5 files changed, 148 insertions(+), 57 deletions(-) create mode 100644 dbgpt/app/scene/exceptions.py create mode 100644 dbgpt/util/retry.py diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 07e1f0a89..b9fa4ea27 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -300,6 +300,15 @@ def __init__(self) -> None: # global dbgpt api key self.API_KEYS = os.getenv("API_KEYS", None) + # Non-streaming scene retries + self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int( + os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1) + ) + # Non-streaming scene parallelism + self.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE = int( + os.getenv("DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE", 1) + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index 9ed49f623..d36d1d05e 100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -21,8 +21,11 @@ from dbgpt.serve.conversation.serve import Serve as ConversationServe from dbgpt.util import get_or_create_event_loop from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async +from dbgpt.util.retry import async_retry from dbgpt.util.tracer import root_tracer, trace +from .exceptions import BaseAppException + logger = logging.getLogger(__name__) CFG = Config() @@ -321,70 +324,67 @@ async def nostream_call(self): "BaseChat.nostream_call", metadata=payload.to_dict() ) logger.info(f"Request: \n{payload}") - ai_response_text = "" payload.span_id = span.span_id try: - with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"): - model_output = await self.call_llm_operator(payload) - - ### output parse - ai_response_text = ( - self.prompt_template.output_parser.parse_model_nostream_resp( - model_output, self.prompt_template.sep - ) + ai_response_text, view_message = await self._no_streaming_call_with_retry( + payload ) - ### model result deal self.current_message.add_ai_message(ai_response_text) - prompt_define_response = ( - self.prompt_template.output_parser.parse_prompt_response( - ai_response_text - ) - ) - metadata = { - "model_output": model_output.to_dict(), - "ai_response_text": ai_response_text, - "prompt_define_response": self._parse_prompt_define_response( - prompt_define_response - ), - } - with root_tracer.start_span("BaseChat.do_action", metadata=metadata): - ### run - result = await blocking_func_to_async( - self._executor, self.do_action, prompt_define_response - ) - - ### llm speaker - speak_to_user = self.get_llm_speak(prompt_define_response) - - # view_message = self.prompt_template.output_parser.parse_view_response( - # speak_to_user, result - # ) - view_message = await blocking_func_to_async( - self._executor, - self.prompt_template.output_parser.parse_view_response, - speak_to_user, - result, - prompt_define_response, - ) - - view_message = view_message.replace("\n", "\\n") self.current_message.add_view_message(view_message) self.message_adjust() - span.end() + except BaseAppException as e: + self.current_message.add_view_message(e.view) + span.end(metadata={"error": str(e)}) except Exception as e: - print(traceback.format_exc()) - logger.error("model response parase faild!" + str(e)) - self.current_message.add_view_message( - f"""ERROR!{str(e)}\n {ai_response_text} """ - ) + view_message = f"ERROR! {str(e)}" + self.current_message.add_view_message(view_message) span.end(metadata={"error": str(e)}) - ### store dialogue + + # Store current conversation await blocking_func_to_async( self._executor, self.current_message.end_current_round ) return self.current_ai_response() + @async_retry( + retries=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE, + parallel_executions=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE, + catch_exceptions=(Exception, BaseAppException), + ) + async def _no_streaming_call_with_retry(self, payload): + with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"): + model_output = await self.call_llm_operator(payload) + + ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp( + model_output, self.prompt_template.sep + ) + prompt_define_response = ( + self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + ) + metadata = { + "model_output": model_output.to_dict(), + "ai_response_text": ai_response_text, + "prompt_define_response": self._parse_prompt_define_response( + prompt_define_response + ), + } + with root_tracer.start_span("BaseChat.do_action", metadata=metadata): + result = await blocking_func_to_async( + self._executor, self.do_action, prompt_define_response + ) + + speak_to_user = self.get_llm_speak(prompt_define_response) + + view_message = await blocking_func_to_async( + self._executor, + self.prompt_template.output_parser.parse_view_response, + speak_to_user, + result, + prompt_define_response, + ) + return ai_response_text, view_message.replace("\n", "\\n") + async def get_llm_response(self): payload = await self._build_model_request() logger.info(f"Request: \n{payload}") diff --git a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py index d2e1eae96..7aad46bd8 100644 --- a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py +++ b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py @@ -9,6 +9,8 @@ from dbgpt.core.interface.output_parser import BaseOutputParser from dbgpt.util.json_utils import serialize +from ...exceptions import AppActionException + CFG = Config() @@ -66,9 +68,10 @@ def parse_view_response(self, speak, data, prompt_response) -> str: param = {} api_call_element = ET.Element("chart-view") err_msg = None + success = False try: if not prompt_response.sql or len(prompt_response.sql) <= 0: - return f"""{speak}""" + raise AppActionException("Can not find sql in response", speak) df = data(prompt_response.sql) param["type"] = prompt_response.display @@ -77,20 +80,26 @@ def parse_view_response(self, speak, data, prompt_response) -> str: df.to_json(orient="records", date_format="iso", date_unit="s") ) view_json_str = json.dumps(param, default=serialize, ensure_ascii=False) + success = True except Exception as e: logger.error("parse_view_response error!" + str(e)) - err_param = {} - err_param["sql"] = f"{prompt_response.sql}" - err_param["type"] = "response_table" + err_param = { + "sql": f"{prompt_response.sql}", + "type": "response_table", + "data": [], + } # err_param["err_msg"] = str(e) - err_param["data"] = [] err_msg = str(e) view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False) # api_call_element.text = view_json_str api_call_element.set("content", view_json_str) result = ET.tostring(api_call_element, encoding="utf-8") - if err_msg: - return f"""{speak} \\n ERROR!{err_msg} \n {result.decode("utf-8")}""" + if not success: + view_content = ( + f'{speak} \\n ERROR!' + f"{err_msg} \n {result.decode('utf-8')}" + ) + raise AppActionException("Generate view content failed", view_content) else: return speak + "\n" + result.decode("utf-8") diff --git a/dbgpt/app/scene/exceptions.py b/dbgpt/app/scene/exceptions.py new file mode 100644 index 000000000..10da0e0db --- /dev/null +++ b/dbgpt/app/scene/exceptions.py @@ -0,0 +1,22 @@ +"""Exceptions for Application.""" +import logging + +logger = logging.getLogger(__name__) + + +class BaseAppException(Exception): + """Base Exception for App""" + + def __init__(self, message: str, view: str): + """Base Exception for App""" + super().__init__(message) + self.message = message + self.view = view + + +class AppActionException(BaseAppException): + """Exception for App Action.""" + + def __init__(self, message: str, view: str): + """Exception for App Action""" + super().__init__(message, view) diff --git a/dbgpt/util/retry.py b/dbgpt/util/retry.py new file mode 100644 index 000000000..c29134c2d --- /dev/null +++ b/dbgpt/util/retry.py @@ -0,0 +1,51 @@ +import asyncio +import logging +import traceback + +logger = logging.getLogger(__name__) + + +def async_retry( + retries: int = 1, parallel_executions: int = 1, catch_exceptions=(Exception,) +): + """Async retry decorator. + + Examples: + .. code-block:: python + + @async_retry(retries=3, parallel_executions=2) + async def my_func(): + # Some code that may raise exceptions + pass + + Args: + retries (int): Number of retries. + parallel_executions (int): Number of parallel executions. + catch_exceptions (tuple): Tuple of exceptions to catch. + """ + + def decorator(func): + async def wrapper(*args, **kwargs): + last_exception = None + for attempt in range(retries): + tasks = [func(*args, **kwargs) for _ in range(parallel_executions)] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in results: + if not isinstance(result, Exception): + return result + if isinstance(result, catch_exceptions): + last_exception = result + logger.error( + f"Attempt {attempt + 1} of {retries} failed with error: " + f"{type(result).__name__}, {str(result)}" + ) + logger.debug(traceback.format_exc()) + + logger.info(f"Retrying... (Attempt {attempt + 1} of {retries})") + + raise last_exception # After all retries, raise the last caught exception + + return wrapper + + return decorator From 6c8333144cd402bd11856eac3382eb72cb89ce33 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 15 Apr 2024 12:32:04 +0800 Subject: [PATCH 2/3] docs: Add retry config to .env.template --- .env.template | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.env.template b/.env.template index 64b451265..1375fb258 100644 --- a/.env.template +++ b/.env.template @@ -247,4 +247,13 @@ DBGPT_LOG_LEVEL=INFO #** API_KEYS **# #*******************************************************************# # API_KEYS - The list of API keys that are allowed to access the API. Each of the below are an option, separated by commas. -# API_KEYS=dbgpt \ No newline at end of file +# API_KEYS=dbgpt + + +#*******************************************************************# +#** Application Config **# +#*******************************************************************# +# Non-streaming scene retries +DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE=1 +# Non-streaming scene parallelism +DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE=1 \ No newline at end of file From 5396afaa39b5cea22069736206dd0cbc1f4be2be Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 15 Apr 2024 12:35:52 +0800 Subject: [PATCH 3/3] docs: Modify .env.template --- .env.template | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.env.template b/.env.template index 1375fb258..7cfd0ce0f 100644 --- a/.env.template +++ b/.env.template @@ -253,7 +253,7 @@ DBGPT_LOG_LEVEL=INFO #*******************************************************************# #** Application Config **# #*******************************************************************# -# Non-streaming scene retries -DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE=1 -# Non-streaming scene parallelism -DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE=1 \ No newline at end of file +## Non-streaming scene retries +# DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE=1 +## Non-streaming scene parallelism +# DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE=1 \ No newline at end of file