Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

修复异步toolcall单测并发interrupt的问题 #677

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/core/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __init__(self, secret_key=None, gateway="", gateway_v2=""):
self.session = AsyncInnerSession()

@staticmethod
def check_response_header(response: ClientResponse):
async def check_response_header(response: ClientResponse):
r"""check_response_header is a helper method for check head status .
:param response: requests.Response.
:rtype:
Expand All @@ -252,7 +252,7 @@ def check_response_header(response: ClientResponse):
if status_code == requests.codes.ok:
return
message = "request_id={} , http status code is {}, body is {}".format(
__class__.response_request_id(response), status_code, response.text
await __class__.response_request_id(response), status_code, await response.text()
)
if status_code == requests.codes.bad_request:
raise BadRequestException(message)
Expand All @@ -268,7 +268,7 @@ def check_response_header(response: ClientResponse):
raise BaseRPCException(message)

@staticmethod
def response_request_id(response: ClientResponse):
async def response_request_id(response: ClientResponse):
r"""response_request_id is a helper method to get the unique request id"""
return response.headers.get("X-Appbuilder-Request-Id", "")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def create_conversation(self) -> str:
response = await self.http_client.session.post(
url, headers=headers, json={"app_id": self.app_id}, timeout=None
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.CreateConversationResponse(**data)
return resp.conversation_id
Expand Down Expand Up @@ -116,8 +116,8 @@ async def run(
response = await self.http_client.session.post(
url, headers=headers, json=req.model_dump(), timeout=None
)
self.http_client.check_response_header(response)
request_id = self.http_client.response_request_id(response)
await self.http_client.check_response_header(response)
request_id = await self.http_client.response_request_id(response)
if stream:
client = AsyncSSEClient(response)
return Message(content=self._iterate_events(request_id, client.events()))
Expand Down Expand Up @@ -164,7 +164,7 @@ async def upload_local_file(self, conversation_id, local_file_path: str) -> str:
response = await self.http_client.session.post(
url, data=multipart_form_data, headers=headers
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.FileUploadResponse(**data)
return resp.id
Expand Down
6 changes: 3 additions & 3 deletions python/core/console/appbuilder_client/async_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ async def __async_run_process__(self):
while not self._is_complete:
if not self._need_tool_call:
res = await self._run()
self.__event_process__(res)
await self.__event_process__(res)
else:
res = await self._submit_tool_output()
self.__event_process__(res)
await self.__event_process__(res)
yield res
if self._need_tool_call and self._is_complete:
self.reset_state()
await self.reset_state()

async def __event_process__(self, run_response):
"""
Expand Down
2 changes: 1 addition & 1 deletion python/tests/component_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_component_white_list():
def get_components(components_list, import_prefix, skip_components):
components = {}
for component in components_list:
if component.__name__ in skip_components:
if component in skip_components:
continue

try:
Expand Down
13 changes: 7 additions & 6 deletions python/tests/test_async_appbuilder_client_toolcall.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ async def interrupt(self, run_context, run_response):
tool_call_id = tool_call.id
tool_res = self.get_current_weather(**tool_call.function.arguments)
# 蓝色打印
print("\033[1;34m", "-> 本地ToolCall结果: ", tool_res, "\033[0m\n")
print("\033[1;34m", "-> 本地ToolCallId: ", tool_call_id, "\033[0m")
print("\033[1;34m", "-> ToolCall结果: ", tool_res, "\033[0m\n")
tool_output.append(
{"tool_call_id": tool_call_id, "output": tool_res})
return tool_output
Expand Down Expand Up @@ -92,9 +93,10 @@ def test_appbuilder_client_tool_call(self):
}
]

appbuilder.logger.setLoglevel("ERROR")
appbuilder.logger.setLoglevel("DEBUG")

async def agent_run(client, conversation_id, query):
async def agent_run(client, query):
conversation_id = await client.create_conversation()
with await client.run_with_handler(
conversation_id=conversation_id,
query=query,
Expand All @@ -105,11 +107,10 @@ async def agent_run(client, conversation_id, query):

async def agent_handle():
client = appbuilder.AsyncAppBuilderClient(self.app_id)
conversation_id = await client.create_conversation()
task1 = asyncio.create_task(
agent_run(client, conversation_id, "北京的天气怎么样"))
agent_run(client, "北京的天气怎么样"))
task2 = asyncio.create_task(
agent_run(client, conversation_id, "上海的天气怎么样"))
agent_run(client, "上海的天气怎么样"))
await asyncio.gather(task1, task2)

await client.http_client.session.close()
Expand Down
60 changes: 32 additions & 28 deletions python/tests/test_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import unittest
import json
import asyncio

from appbuilder.core._client import HTTPClient, AsyncHTTPClient
from appbuilder.core._exception import *
Expand Down Expand Up @@ -100,34 +101,37 @@ def test_core_client_check_response_header(self):
HTTPClient.check_response_header(response)

def test_core_client_check_async_response_header(self):
# 测试各种response报错
response = AsyncResponse(
status_code=400,
headers={'Content-Type': 'application/json'},
text='{"code": 0, "message": "success"}'
)
with self.assertRaises(BadRequestException):
AsyncHTTPClient.check_response_header(response)

response.status = 403
with self.assertRaises(ForbiddenException):
AsyncHTTPClient.check_response_header(response)

response.status = 404
with self.assertRaises(NotFoundException):
AsyncHTTPClient.check_response_header(response)

response.status = 428
with self.assertRaises(PreconditionFailedException):
AsyncHTTPClient.check_response_header(response)

response.status = 500
with self.assertRaises(InternalServerErrorException):
AsyncHTTPClient.check_response_header(response)

response.status = 201
with self.assertRaises(BaseRPCException):
AsyncHTTPClient.check_response_header(response)
async def run_test():
# 测试各种response报错
response = AsyncResponse(
status_code=400,
headers={'Content-Type': 'application/json'},
text=lambda:asyncio.sleep(0) or '{"code": 0, "message": "success"}'
)
with self.assertRaises(BadRequestException):
await AsyncHTTPClient.check_response_header(response)

response.status = 403
with self.assertRaises(ForbiddenException):
await AsyncHTTPClient.check_response_header(response)

response.status = 404
with self.assertRaises(NotFoundException):
await AsyncHTTPClient.check_response_header(response)

response.status = 428
with self.assertRaises(PreconditionFailedException):
await AsyncHTTPClient.check_response_header(response)

response.status = 500
with self.assertRaises(InternalServerErrorException):
await AsyncHTTPClient.check_response_header(response)

response.status = 201
with self.assertRaises(BaseRPCException):
await AsyncHTTPClient.check_response_header(response)
loop = asyncio.get_event_loop()
loop.run_until_complete(run_test())

def test_core_client_check_response_json(self):
data = {
Expand Down
Loading