Skip to content

Commit

Permalink
feat: Support retry for 'Chat Data'
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Apr 15, 2024
1 parent 2e2e120 commit db82380
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 57 deletions.
9 changes: 9 additions & 0 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ def __init__(self) -> None:
# global dbgpt api key
self.API_KEYS = os.getenv("API_KEYS", None)

# Non-streaming scene retries
self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1)
)
# Non-streaming scene parallelism
self.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE = int(
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE", 1)
)

@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager
Expand Down
100 changes: 50 additions & 50 deletions dbgpt/app/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.retry import async_retry
from dbgpt.util.tracer import root_tracer, trace

from .exceptions import BaseAppException

logger = logging.getLogger(__name__)
CFG = Config()

Expand Down Expand Up @@ -321,70 +324,67 @@ async def nostream_call(self):
"BaseChat.nostream_call", metadata=payload.to_dict()
)
logger.info(f"Request: \n{payload}")
ai_response_text = ""
payload.span_id = span.span_id
try:
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await self.call_llm_operator(payload)

### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
model_output, self.prompt_template.sep
)
ai_response_text, view_message = await self._no_streaming_call_with_retry(
payload
)
### model result deal
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
metadata = {
"model_output": model_output.to_dict(),
"ai_response_text": ai_response_text,
"prompt_define_response": self._parse_prompt_define_response(
prompt_define_response
),
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
### run
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)

### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)

# view_message = self.prompt_template.output_parser.parse_view_response(
# speak_to_user, result
# )
view_message = await blocking_func_to_async(
self._executor,
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
prompt_define_response,
)

view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
self.message_adjust()

span.end()
except BaseAppException as e:
self.current_message.add_view_message(e.view)
span.end(metadata={"error": str(e)})
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
view_message = f"<span style='color:red'>ERROR!</span> {str(e)}"
self.current_message.add_view_message(view_message)
span.end(metadata={"error": str(e)})
### store dialogue

# Store current conversation
await blocking_func_to_async(
self._executor, self.current_message.end_current_round
)
return self.current_ai_response()

@async_retry(
retries=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
parallel_executions=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
catch_exceptions=(Exception, BaseAppException),
)
async def _no_streaming_call_with_retry(self, payload):
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await self.call_llm_operator(payload)

ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(
model_output, self.prompt_template.sep
)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
)
metadata = {
"model_output": model_output.to_dict(),
"ai_response_text": ai_response_text,
"prompt_define_response": self._parse_prompt_define_response(
prompt_define_response
),
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)

speak_to_user = self.get_llm_speak(prompt_define_response)

view_message = await blocking_func_to_async(
self._executor,
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
prompt_define_response,
)
return ai_response_text, view_message.replace("\n", "\\n")

async def get_llm_response(self):
payload = await self._build_model_request()
logger.info(f"Request: \n{payload}")
Expand Down
23 changes: 16 additions & 7 deletions dbgpt/app/scene/chat_db/auto_execute/out_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.json_utils import serialize

from ...exceptions import AppActionException

CFG = Config()


Expand Down Expand Up @@ -66,9 +68,10 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
api_call_element = ET.Element("chart-view")
err_msg = None
success = False
try:
if not prompt_response.sql or len(prompt_response.sql) <= 0:
return f"""{speak}"""
raise AppActionException("Can not find sql in response", speak)

df = data(prompt_response.sql)
param["type"] = prompt_response.display
Expand All @@ -77,20 +80,26 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
except Exception as e:
logger.error("parse_view_response error!" + str(e))
err_param = {}
err_param["sql"] = f"{prompt_response.sql}"
err_param["type"] = "response_table"
err_param = {
"sql": f"{prompt_response.sql}",
"type": "response_table",
"data": [],
}
# err_param["err_msg"] = str(e)
err_param["data"] = []
err_msg = str(e)
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)

# api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
if err_msg:
return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}"""
if not success:
view_content = (
f'{speak} \\n <span style="color:red">ERROR!</span>'
f"{err_msg} \n {result.decode('utf-8')}"
)
raise AppActionException("Generate view content failed", view_content)
else:
return speak + "\n" + result.decode("utf-8")
22 changes: 22 additions & 0 deletions dbgpt/app/scene/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Exceptions for Application."""
import logging

logger = logging.getLogger(__name__)


class BaseAppException(Exception):
"""Base Exception for App"""

def __init__(self, message: str, view: str):
"""Base Exception for App"""
super().__init__(message)
self.message = message
self.view = view


class AppActionException(BaseAppException):
"""Exception for App Action."""

def __init__(self, message: str, view: str):
"""Exception for App Action"""
super().__init__(message, view)
51 changes: 51 additions & 0 deletions dbgpt/util/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio
import logging
import traceback

logger = logging.getLogger(__name__)


def async_retry(
retries: int = 1, parallel_executions: int = 1, catch_exceptions=(Exception,)
):
"""Async retry decorator.
Examples:
.. code-block:: python
@async_retry(retries=3, parallel_executions=2)
async def my_func():
# Some code that may raise exceptions
pass
Args:
retries (int): Number of retries.
parallel_executions (int): Number of parallel executions.
catch_exceptions (tuple): Tuple of exceptions to catch.
"""

def decorator(func):
async def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(retries):
tasks = [func(*args, **kwargs) for _ in range(parallel_executions)]
results = await asyncio.gather(*tasks, return_exceptions=True)

for result in results:
if not isinstance(result, Exception):
return result
if isinstance(result, catch_exceptions):
last_exception = result
logger.error(
f"Attempt {attempt + 1} of {retries} failed with error: "
f"{type(result).__name__}, {str(result)}"
)
logger.debug(traceback.format_exc())

logger.info(f"Retrying... (Attempt {attempt + 1} of {retries})")

raise last_exception # After all retries, raise the last caught exception

return wrapper

return decorator

0 comments on commit db82380

Please sign in to comment.