Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support retry for 'Chat Data' #1419

Merged
merged 3 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,13 @@ DBGPT_LOG_LEVEL=INFO
#** API_KEYS **#
#*******************************************************************#
# API_KEYS - The list of API keys that are allowed to access the API. Each of the below is an option, separated by commas.
# API_KEYS=dbgpt
# API_KEYS=dbgpt


#*******************************************************************#
#** Application Config **#
#*******************************************************************#
## Non-streaming scene retries
# DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE=1
## Non-streaming scene parallelism
# DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE=1
9 changes: 9 additions & 0 deletions dbgpt/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ def __init__(self) -> None:
# global dbgpt api key
self.API_KEYS = os.getenv("API_KEYS", None)

# Non-streaming scene retries
self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1)
)
# Non-streaming scene parallelism
self.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE = int(
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE", 1)
)

@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager
Expand Down
100 changes: 50 additions & 50 deletions dbgpt/app/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.retry import async_retry
from dbgpt.util.tracer import root_tracer, trace

from .exceptions import BaseAppException

logger = logging.getLogger(__name__)
CFG = Config()

Expand Down Expand Up @@ -321,70 +324,67 @@ async def nostream_call(self):
"BaseChat.nostream_call", metadata=payload.to_dict()
)
logger.info(f"Request: \n{payload}")
ai_response_text = ""
payload.span_id = span.span_id
try:
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
model_output = await self.call_llm_operator(payload)

### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
model_output, self.prompt_template.sep
)
ai_response_text, view_message = await self._no_streaming_call_with_retry(
payload
)
### model result deal
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
metadata = {
"model_output": model_output.to_dict(),
"ai_response_text": ai_response_text,
"prompt_define_response": self._parse_prompt_define_response(
prompt_define_response
),
}
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
### run
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)

### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)

# view_message = self.prompt_template.output_parser.parse_view_response(
# speak_to_user, result
# )
view_message = await blocking_func_to_async(
self._executor,
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
prompt_define_response,
)

view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
self.message_adjust()

span.end()
except BaseAppException as e:
self.current_message.add_view_message(e.view)
span.end(metadata={"error": str(e)})
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild!" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
view_message = f"<span style='color:red'>ERROR!</span> {str(e)}"
self.current_message.add_view_message(view_message)
span.end(metadata={"error": str(e)})
### store dialogue

# Store current conversation
await blocking_func_to_async(
self._executor, self.current_message.end_current_round
)
return self.current_ai_response()

@async_retry(
    retries=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
    # Bug fix: parallelism must be driven by the parallelism setting, not the
    # retries setting (DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE was
    # defined in config but never used).
    parallel_executions=CFG.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE,
    catch_exceptions=(Exception, BaseAppException),
)
async def _no_streaming_call_with_retry(self, payload):
    """Run one non-streaming LLM round-trip and post-process its output.

    Wrapped by ``async_retry``: each attempt may run several parallel
    executions, and any caught exception triggers a retry according to the
    application config.

    Args:
        payload: Model request payload passed to ``call_llm_operator``.

    Returns:
        tuple: ``(ai_response_text, view_message)`` — the parsed model text
        and the rendered view content with newlines escaped (``\\n``).

    Raises:
        BaseAppException: When rendering the view content fails downstream
            (e.g. in the output parser).
        Exception: Any other error from the model call or parsing.
    """
    with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
        model_output = await self.call_llm_operator(payload)

    # Parse the raw (non-streaming) model output into plain response text.
    ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(
        model_output, self.prompt_template.sep
    )
    prompt_define_response = (
        self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
    )
    metadata = {
        "model_output": model_output.to_dict(),
        "ai_response_text": ai_response_text,
        "prompt_define_response": self._parse_prompt_define_response(
            prompt_define_response
        ),
    }
    with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
        # do_action may block (e.g. execute SQL), so run it in the executor.
        result = await blocking_func_to_async(
            self._executor, self.do_action, prompt_define_response
        )

    speak_to_user = self.get_llm_speak(prompt_define_response)

    # View rendering is also potentially blocking; offload it as well.
    view_message = await blocking_func_to_async(
        self._executor,
        self.prompt_template.output_parser.parse_view_response,
        speak_to_user,
        result,
        prompt_define_response,
    )
    # Escape newlines so the view travels as a single-line message payload.
    return ai_response_text, view_message.replace("\n", "\\n")

async def get_llm_response(self):
payload = await self._build_model_request()
logger.info(f"Request: \n{payload}")
Expand Down
23 changes: 16 additions & 7 deletions dbgpt/app/scene/chat_db/auto_execute/out_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.json_utils import serialize

from ...exceptions import AppActionException

CFG = Config()


Expand Down Expand Up @@ -66,9 +68,10 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
api_call_element = ET.Element("chart-view")
err_msg = None
success = False
try:
if not prompt_response.sql or len(prompt_response.sql) <= 0:
return f"""{speak}"""
raise AppActionException("Can not find sql in response", speak)

df = data(prompt_response.sql)
param["type"] = prompt_response.display
Expand All @@ -77,20 +80,26 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
success = True
except Exception as e:
logger.error("parse_view_response error!" + str(e))
err_param = {}
err_param["sql"] = f"{prompt_response.sql}"
err_param["type"] = "response_table"
err_param = {
"sql": f"{prompt_response.sql}",
"type": "response_table",
"data": [],
}
# err_param["err_msg"] = str(e)
err_param["data"] = []
err_msg = str(e)
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)

# api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
if err_msg:
return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}"""
if not success:
view_content = (
f'{speak} \\n <span style="color:red">ERROR!</span>'
f"{err_msg} \n {result.decode('utf-8')}"
)
raise AppActionException("Generate view content failed", view_content)
else:
return speak + "\n" + result.decode("utf-8")
22 changes: 22 additions & 0 deletions dbgpt/app/scene/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Exceptions for Application."""
import logging

logger = logging.getLogger(__name__)


class BaseAppException(Exception):
    """Root exception type for application-scene errors.

    Besides the standard exception ``message``, it carries a ``view`` string
    that the chat layer can render directly to the end user.
    """

    def __init__(self, message: str, view: str):
        """Initialize the exception.

        Args:
            message (str): Human-readable error description.
            view (str): Renderable view content shown to the user.
        """
        super().__init__(message)
        self.view = view
        self.message = message


class AppActionException(BaseAppException):
    """Exception raised when executing an application action fails.

    Uses the inherited ``(message, view)`` constructor from
    :class:`BaseAppException` unchanged — the previous override merely
    delegated with the identical signature, so it has been removed.
    """
51 changes: 51 additions & 0 deletions dbgpt/util/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import asyncio
import functools
import logging
import traceback

logger = logging.getLogger(__name__)


def async_retry(
    retries: int = 1, parallel_executions: int = 1, catch_exceptions=(Exception,)
):
    """Async retry decorator.

    Each attempt launches ``parallel_executions`` concurrent invocations of
    the wrapped coroutine and returns the first successful result. If every
    invocation fails with one of ``catch_exceptions``, the attempt is retried
    up to ``retries`` times; when all attempts are exhausted, the last caught
    exception is re-raised. An exception outside ``catch_exceptions`` is
    treated as non-retryable and propagated immediately.

    Examples:
        .. code-block:: python

            @async_retry(retries=3, parallel_executions=2)
            async def my_func():
                # Some code that may raise exceptions
                pass

    Args:
        retries (int): Number of attempts.
        parallel_executions (int): Number of parallel executions per attempt.
        catch_exceptions (tuple): Tuple of exception types that trigger a
            retry; any other exception is raised at once.
    """

    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped coroutine's metadata
        async def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(retries):
                tasks = [func(*args, **kwargs) for _ in range(parallel_executions)]
                # return_exceptions=True: failures come back as result
                # objects instead of aborting the whole gather.
                results = await asyncio.gather(*tasks, return_exceptions=True)

                for result in results:
                    if not isinstance(result, BaseException):
                        return result
                    if isinstance(result, catch_exceptions):
                        last_exception = result
                        logger.error(
                            f"Attempt {attempt + 1} of {retries} failed with error: "
                            f"{type(result).__name__}, {str(result)}"
                        )
                        # traceback.format_exc() would log the *current*
                        # (empty) exception context; format the gathered
                        # exception object explicitly instead.
                        logger.debug(
                            "".join(
                                traceback.format_exception(
                                    type(result), result, result.__traceback__
                                )
                            )
                        )
                    else:
                        # Non-retryable exception: propagate immediately
                        # instead of silently ignoring it.
                        raise result

                if attempt + 1 < retries:
                    logger.info(f"Retrying... (Attempt {attempt + 1} of {retries})")

            if last_exception is not None:
                # After all retries, raise the last caught exception
                raise last_exception
            # Degenerate configuration (retries < 1 or parallel_executions
            # < 1): nothing ever ran, so there is no exception to re-raise.
            raise RuntimeError("async_retry: no executions were attempted")

        return wrapper

    return decorator
Loading