Skip to content

Commit

Permalink
修复异步toolcall单测并发interrupt的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj committed Dec 20, 2024
1 parent 9a95c5e commit 8726783
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 44 deletions.
6 changes: 3 additions & 3 deletions python/core/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __init__(self, secret_key=None, gateway="", gateway_v2=""):
self.session = AsyncInnerSession()

@staticmethod
def check_response_header(response: ClientResponse):
async def check_response_header(response: ClientResponse):
r"""check_response_header is a helper method for check head status .
:param response: requests.Response.
:rtype:
Expand All @@ -252,7 +252,7 @@ def check_response_header(response: ClientResponse):
if status_code == requests.codes.ok:
return
message = "request_id={} , http status code is {}, body is {}".format(
__class__.response_request_id(response), status_code, response.text
await __class__.response_request_id(response), status_code, await response.text()
)
if status_code == requests.codes.bad_request:
raise BadRequestException(message)
Expand All @@ -268,7 +268,7 @@ def check_response_header(response: ClientResponse):
raise BaseRPCException(message)

@staticmethod
def response_request_id(response: ClientResponse):
async def response_request_id(response: ClientResponse):
r"""response_request_id is a helper method to get the unique request id"""
return response.headers.get("X-Appbuilder-Request-Id", "")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def create_conversation(self) -> str:
response = await self.http_client.session.post(
url, headers=headers, json={"app_id": self.app_id}, timeout=None
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.CreateConversationResponse(**data)
return resp.conversation_id
Expand Down Expand Up @@ -116,8 +116,8 @@ async def run(
response = await self.http_client.session.post(
url, headers=headers, json=req.model_dump(), timeout=None
)
self.http_client.check_response_header(response)
request_id = self.http_client.response_request_id(response)
await self.http_client.check_response_header(response)
request_id = await self.http_client.response_request_id(response)
if stream:
client = AsyncSSEClient(response)
return Message(content=self._iterate_events(request_id, client.events()))
Expand Down Expand Up @@ -164,7 +164,7 @@ async def upload_local_file(self, conversation_id, local_file_path: str) -> str:
response = await self.http_client.session.post(
url, data=multipart_form_data, headers=headers
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.FileUploadResponse(**data)
return resp.id
Expand Down
6 changes: 3 additions & 3 deletions python/core/console/appbuilder_client/async_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ async def __async_run_process__(self):
while not self._is_complete:
if not self._need_tool_call:
res = await self._run()
self.__event_process__(res)
await self.__event_process__(res)
else:
res = await self._submit_tool_output()
self.__event_process__(res)
await self.__event_process__(res)
yield res
if self._need_tool_call and self._is_complete:
self.reset_state()
await self.reset_state()

async def __event_process__(self, run_response):
"""
Expand Down
13 changes: 7 additions & 6 deletions python/tests/test_async_appbuilder_client_toolcall.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ async def interrupt(self, run_context, run_response):
tool_call_id = tool_call.id
tool_res = self.get_current_weather(**tool_call.function.arguments)
# 蓝色打印
print("\033[1;34m", "-> 本地ToolCall结果: ", tool_res, "\033[0m\n")
print("\033[1;34m", "-> 本地ToolCallId: ", tool_call_id, "\033[0m")
print("\033[1;34m", "-> ToolCall结果: ", tool_res, "\033[0m\n")
tool_output.append(
{"tool_call_id": tool_call_id, "output": tool_res})
return tool_output
Expand Down Expand Up @@ -92,9 +93,10 @@ def test_appbuilder_client_tool_call(self):
}
]

appbuilder.logger.setLoglevel("ERROR")
appbuilder.logger.setLoglevel("DEBUG")

async def agent_run(client, conversation_id, query):
async def agent_run(client, query):
conversation_id = await client.create_conversation()
with await client.run_with_handler(
conversation_id=conversation_id,
query=query,
Expand All @@ -105,11 +107,10 @@ async def agent_run(client, conversation_id, query):

async def agent_handle():
client = appbuilder.AsyncAppBuilderClient(self.app_id)
conversation_id = await client.create_conversation()
task1 = asyncio.create_task(
agent_run(client, conversation_id, "北京的天气怎么样"))
agent_run(client, "北京的天气怎么样"))
task2 = asyncio.create_task(
agent_run(client, conversation_id, "上海的天气怎么样"))
agent_run(client, "上海的天气怎么样"))
await asyncio.gather(task1, task2)

await client.http_client.session.close()
Expand Down
60 changes: 32 additions & 28 deletions python/tests/test_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import unittest
import json
import asyncio

from appbuilder.core._client import HTTPClient, AsyncHTTPClient
from appbuilder.core._exception import *
Expand Down Expand Up @@ -100,34 +101,37 @@ def test_core_client_check_response_header(self):
HTTPClient.check_response_header(response)

def test_core_client_check_async_response_header(self):
# 测试各种response报错
response = AsyncResponse(
status_code=400,
headers={'Content-Type': 'application/json'},
text='{"code": 0, "message": "success"}'
)
with self.assertRaises(BadRequestException):
AsyncHTTPClient.check_response_header(response)

response.status = 403
with self.assertRaises(ForbiddenException):
AsyncHTTPClient.check_response_header(response)

response.status = 404
with self.assertRaises(NotFoundException):
AsyncHTTPClient.check_response_header(response)

response.status = 428
with self.assertRaises(PreconditionFailedException):
AsyncHTTPClient.check_response_header(response)

response.status = 500
with self.assertRaises(InternalServerErrorException):
AsyncHTTPClient.check_response_header(response)

response.status = 201
with self.assertRaises(BaseRPCException):
AsyncHTTPClient.check_response_header(response)
async def run_test():
# 测试各种response报错
response = AsyncResponse(
status_code=400,
headers={'Content-Type': 'application/json'},
text=lambda:asyncio.sleep(0) or '{"code": 0, "message": "success"}'
)
with self.assertRaises(BadRequestException):
await AsyncHTTPClient.check_response_header(response)

response.status = 403
with self.assertRaises(ForbiddenException):
await AsyncHTTPClient.check_response_header(response)

response.status = 404
with self.assertRaises(NotFoundException):
await AsyncHTTPClient.check_response_header(response)

response.status = 428
with self.assertRaises(PreconditionFailedException):
await AsyncHTTPClient.check_response_header(response)

response.status = 500
with self.assertRaises(InternalServerErrorException):
await AsyncHTTPClient.check_response_header(response)

response.status = 201
with self.assertRaises(BaseRPCException):
await AsyncHTTPClient.check_response_header(response)
loop = asyncio.get_event_loop()
loop.run_until_complete(run_test())

def test_core_client_check_response_json(self):
data = {
Expand Down

0 comments on commit 8726783

Please sign in to comment.