diff --git a/python/__init__.py b/python/__init__.py index 04de5e9f8..432ae0359 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -178,6 +178,7 @@ def get_default_header(): from appbuilder.core.utils import get_model_list from appbuilder.core.console.appbuilder_client.appbuilder_client import AppBuilderClient +from appbuilder.core.console.appbuilder_client.async_appbuilder_client import AsyncAppBuilderClient from appbuilder.core.console.appbuilder_client.appbuilder_client import AgentBuilder from appbuilder.core.console.appbuilder_client.appbuilder_client import get_app_list, get_all_apps, describe_apps from appbuilder.core.console.knowledge_base.knowledge_base import KnowledgeBase @@ -202,19 +203,20 @@ def get_default_header(): from appbuilder.utils.trace.tracer import AppBuilderTracer, AppbuilderInstrumentor __all__ = [ - 'logger', - 'BadRequestException', - 'ForbiddenException', - 'NotFoundException', - 'PreconditionFailedException', - 'InternalServerErrorException', - 'HTTPConnectionException', - 'AppBuilderServerException', - 'AppbuilderTraceException', - 'AppbuilderTestToolEval', - 'AutomaticTestToolEval', + "logger", + "BadRequestException", + "ForbiddenException", + "NotFoundException", + "PreconditionFailedException", + "InternalServerErrorException", + "HTTPConnectionException", + "AppBuilderServerException", + "AppbuilderTraceException", + "AppbuilderTestToolEval", + "AutomaticTestToolEval", "get_model_list", "AppBuilderClient", + "AsyncAppBuilderClient", "AgentBuilder", "get_app_list", "get_all_apps", @@ -232,5 +234,5 @@ def get_default_header(): "AssistantEventHandler", "AssistantStreamManager", "AppBuilderTracer", - "AppbuilderInstrumentor" + "AppbuilderInstrumentor", ] + __COMPONENTS__ diff --git a/python/core/_client.py b/python/core/_client.py index 9c226dced..83a2c524d 100644 --- a/python/core/_client.py +++ b/python/core/_client.py @@ -21,11 +21,12 @@ import requests from requests.adapters import HTTPAdapter, Retry +from aiohttp import ClientResponse from appbuilder import get_default_header from appbuilder.core._exception import * -from appbuilder.core._session import InnerSession +from appbuilder.core._session import InnerSession, AsyncInnerSession from appbuilder.core.constants import ( GATEWAY_URL, GATEWAY_URL_V2, @@ -100,7 +101,8 @@ def _init_secret_key(self, secret_key: str): secret_key_prefix = os.getenv("SECRET_KEY_PREFIX", SECRET_KEY_PREFIX) if not self.secret_key.startswith(secret_key_prefix): - self.secret_key = "{} {}".format(secret_key_prefix, self.secret_key) + self.secret_key = "{} {}".format( + secret_key_prefix, self.secret_key) logger.debug("AppBuilder Secret key: {}\n".format(self.secret_key)) @@ -181,7 +183,8 @@ def check_console_response(response: requests.Response): data = response.json() if "code" in data and data.get("code") != 0: requestId = __class__.response_request_id(response) - raise AppBuilderServerException(requestId, data["code"], data["message"]) + raise AppBuilderServerException( + requestId, data["code"], data["message"]) def auth_header(self, request_id: Optional[str] = None): r"""auth_header is a helper method return auth info""" @@ -234,6 +237,42 @@ def inner(*args, **kwargs): return inner +class AsyncHTTPClient(HTTPClient): + def __init__(self, secret_key=None, gateway="", gateway_v2=""): + super().__init__(secret_key, gateway, gateway_v2) + self.session = AsyncInnerSession() + + @staticmethod + def check_response_header(response: ClientResponse): + r"""check_response_header is a helper method for check head status . + :param response: requests.Response. + :rtype: + """ + status_code = response.status + if status_code == requests.codes.ok: + return + message = "request_id={} , http status code is {}, body is {}".format( + __class__.response_request_id(response), status_code, response.text + ) + if status_code == requests.codes.bad_request: + raise BadRequestException(message) + elif status_code == requests.codes.forbidden: + raise ForbiddenException(message) + elif status_code == requests.codes.not_found: + raise NotFoundException(message) + elif status_code == requests.codes.precondition_required: + raise PreconditionFailedException(message) + elif status_code == requests.codes.internal_server_error: + raise InternalServerErrorException(message) + else: + raise BaseRPCException(message) + + @staticmethod + def response_request_id(response: ClientResponse): + r"""response_request_id is a helper method to get the unique request id""" + return response.headers.get("X-Appbuilder-Request-Id", "") + + class AssistantHTTPClient(HTTPClient): def service_url(self, sub_path: str, prefix: str = None): """ diff --git a/python/core/_session.py b/python/core/_session.py index be3877e0d..31e398228 100644 --- a/python/core/_session.py +++ b/python/core/_session.py @@ -14,6 +14,8 @@ import requests import json +import aiohttp +from aiohttp import ClientSession, hdrs from appbuilder.utils.logger_util import logger from appbuilder.utils.trace.tracer_wrapper import session_post @@ -72,3 +74,57 @@ def get(self, url, **kwargs): @session_post def put(self, url, data=None, **kwargs): return super().put(url=url, data=data, **kwargs) + + +class AsyncInnerSession(ClientSession): + + def __init__(self, *args, **kwargs): + """ + Initialize inner session. + """ + super(AsyncInnerSession, self).__init__(*args, **kwargs) + + async def build_curl(self, method, url, data=None, json_data=None, **kwargs) -> str: + """ + Generate cURL command from prepared request object. + """ + curl = "curl -X {0} -L '{1}' \\\n".format(method, url) + + headers = kwargs.get("headers", {}) + headers_strs = [ + "-H '{0}: {1}' \\".format(k, v) for k, v in headers.items()] + if headers_strs: + headers_strs[-1] = headers_strs[-1].rstrip(" \\") + curl += "\n".join(headers_strs) + + if data: + try: + body = "'{0}'".format(json.dumps(data, ensure_ascii=False)) + curl += " \\\n-d {0}".format(body) + except: + pass + elif json_data: + body = "'{0}'".format(json.dumps(json_data, ensure_ascii=False)) + curl += " \\\n-d {0}".format(body) + + return curl + + @session_post + async def post(self, url, data=None, json=None, **kwargs): + logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_POST, url, data=data, json_data=json, **kwargs) + "\n") + return await super().post(url=url, data=data, json=json, **kwargs) + + @session_post + async def delete(self, url, **kwargs): + logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_DELETE, url, **kwargs) + "\n") + return await super().delete(url=url, **kwargs) + + @session_post + async def get(self, url, **kwargs): + logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_GET, url, **kwargs) + "\n") + return await super().get(url=url, **kwargs) + + @session_post + async def put(self, url, data=None, **kwargs): + logger.debug("Curl Command:\n" + await self.build_curl(hdrs.METH_PUT, url, data=data, **kwargs) + "\n") + return await super().put(url=url, data=data, **kwargs) diff --git a/python/core/component.py b/python/core/component.py index b2c005b47..23bcbac77 100644 --- a/python/core/component.py +++ b/python/core/component.py @@ -22,7 +22,7 @@ from typing import ( Dict, List, Optional, Any, Generator, Union, AsyncGenerator) from appbuilder.core.utils import ttl_lru_cache -from appbuilder.core._client import HTTPClient +from appbuilder.core._client import HTTPClient, AsyncHTTPClient from appbuilder.core.message import Message @@ -118,18 +118,22 @@ class PlanStep(BaseModel, extra='allow'): arguments: dict = Field(default={}, description="step参数") thought: str = Field(default="", description="step思考结果") + class Plan(BaseModel, extra='allow'): detail: str = Field(default="", description="计划详情") steps: list[PlanStep] = Field(default=[], description="步骤列表") + class FunctionCall(BaseModel, extra='allow'): thought: str = Field(default="", description="思考结果") name: str = Field(default="", description="工具名") arguments: dict = Field(default={}, description="参数列表") - + + class Json(BaseModel, extra='allow'): data: str = Field(default="", description="json数据") + class Content(BaseModel): name: str = Field(default="", description="介绍当前yield内容的阶段名, 使用name的必要条件,是同一组件会输出不同type的content,并且需要加以区分,方便前端渲染与用户展示") @@ -141,10 +145,10 @@ class Content(BaseModel): description="大模型的token用量, ") metrics: dict = Field(default={}, description="耗时、性能、内存等trace及debug所需信息") - type: str = Field(default="text", + type: str = Field(default="text", description="代表event 类型,包括 text、code、files、urls、oral_text、references、image、chart、audio该字段的取值决定了下面text字段的内容结构") - text: Union[Text, Code, Files, Urls, OralText, References, Image, Chart, Audio, Plan, Json, FunctionCall] = Field(default=Text, - description="代表当前 event 元素的内容,每一种 event 对应的 text 结构固定") + text: Union[Text, Code, Files, Urls, OralText, References, Image, Chart, Audio, Plan, Json, FunctionCall] = Field(default=Text, + description="代表当前 event 元素的内容,每一种 event 对应的 text 结构固定") @field_validator('text', mode='before') def set_text(cls, v, values, **kwargs): @@ -180,7 +184,7 @@ class ComponentOutput(BaseModel): role: str = Field(default="tool", description="role是区分当前消息来源的重要字段,对于绝大多数组件而言,都是填写tool,标明role所在的消息来源为组件。部分思考及问答组件,role需要填写为assistant") content: list[Content] = Field(default=[], - description="content是当前组件返回内容的主要payload,List[Content],每个Content Dict 包括了当前输出的一个元素") + description="content是当前组件返回内容的主要payload,List[Content],每个Content Dict 包括了当前输出的一个元素") class Component: @@ -202,6 +206,7 @@ def __init__( secret_key: Optional[str] = None, gateway: str = "", lazy_certification: bool = False, + is_aysnc: bool = False, **kwargs ): r"""Component初始化方法. @@ -219,6 +224,7 @@ def __init__( self.gateway = gateway self._http_client = None self.lazy_certification = lazy_certification + self.is_async = is_aysnc if not self.lazy_certification: self.set_secret_key_and_gateway(self.secret_key, self.gateway) @@ -236,7 +242,10 @@ def set_secret_key_and_gateway(self, secret_key: Optional[str] = None, gateway: """ self.secret_key = secret_key self.gateway = gateway - self._http_client = HTTPClient(self.secret_key, self.gateway) + if self.is_async: + self._http_client = AsyncHTTPClient(self.secret_key, self.gateway) + else: + self._http_client = HTTPClient(self.secret_key, self.gateway) @property def http_client(self): @@ -251,7 +260,11 @@ def http_client(self): """ if self._http_client is None: - self._http_client = HTTPClient(self.secret_key, self.gateway) + if self.is_async: + self._http_client = AsyncHTTPClient( + self.secret_key, self.gateway) + else: + self._http_client = HTTPClient(self.secret_key, self.gateway) return self._http_client def __call__(self, *inputs, **kwargs): @@ -521,7 +534,8 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra elif type == "json": text = {"data": text} else: - raise ValueError("Only when type=text/code/urls/oral_text, string text is allowed! Please give dict text") + raise ValueError( + "Only when type=text/code/urls/oral_text, string text is allowed! Please give dict text") elif isinstance(text, dict): if type == "text": key_list = ["info"] @@ -534,7 +548,8 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra elif type == "files": key_list = ["filename", "url"] elif type == "references": - key_list = ["type", "resource_type", "icon", "site_name", "source", "doc_id", "title", "content", "image_content", "image_url", "video_url"] + key_list = ["type", "resource_type", "icon", "site_name", "source", + "doc_id", "title", "content", "image_content", "image_url", "video_url"] elif type == "image": key_list = ["filename", "url"] elif type == "chart": @@ -551,7 +566,8 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra else: raise ValueError("text must be str or dict") - assert role in ["tool", "assistant"], "role must be 'tool' or 'assistant'" + assert role in [ + "tool", "assistant"], "role must be 'tool' or 'assistant'" result = { "role": role, "content": [{ @@ -564,4 +580,4 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra "metrics": metrics }] } - return ComponentOutput(**result) \ No newline at end of file + return ComponentOutput(**result) diff --git a/python/core/console/appbuilder_client/appbuilder_client.py b/python/core/console/appbuilder_client/appbuilder_client.py index 4814f41d4..5d8779a4a 100644 --- a/python/core/console/appbuilder_client/appbuilder_client.py +++ b/python/core/console/appbuilder_client/appbuilder_client.py @@ -17,7 +17,7 @@ import json import uuid import queue -from typing import Optional,Union +from typing import Optional, Union from appbuilder.core.component import Message, Component from appbuilder.core.manifest.models import Manifest from appbuilder.core.console.appbuilder_client import data_class @@ -81,7 +81,7 @@ def describe_apps( marker: Optional[str] = None, maxKeys: int = 10, secret_key: Optional[str] = None, - gateway: Optional[str] = None + gateway: Optional[str] = None, ) -> list[data_class.AppOverview]: """ 该接口查询用户下状态为已发布的应用列表 @@ -100,9 +100,7 @@ def describe_apps( headers = client.auth_header_v2() headers["Content-Type"] = "application/json" url = client.service_url_v2("/app?Action=DescribeApps") - request = data_class.DescribeAppsRequest( - MaxKeys=maxKeys, Marker=marker - ) + request = data_class.DescribeAppsRequest(MaxKeys=maxKeys, Marker=marker) response = client.session.post( url=url, json=request.model_dump(), @@ -225,7 +223,8 @@ def upload_local_file(self, conversation_id, local_file_path: str) -> str: """ if len(conversation_id) == 0: raise ValueError( - "conversation_id is empty, you can run self.create_conversation to get a conversation_id") + "conversation_id is empty, you can run self.create_conversation to get a conversation_id" + ) filepath = os.path.abspath(local_file_path) if not os.path.exists(filepath): @@ -247,17 +246,19 @@ def upload_local_file(self, conversation_id, local_file_path: str) -> str: return resp.id @client_run_trace - def run(self, conversation_id: str, - query: str = "", - file_ids: list = [], - stream: bool = False, - tools: list[Union[data_class.Tool,Manifest]]= None, - tool_outputs: list[data_class.ToolOutput] = None, - tool_choice: data_class.ToolChoice = None, - end_user_id: str = None, - action: data_class.Action = None, - **kwargs - ) -> Message: + def run( + self, + conversation_id: str, + query: str = "", + file_ids: list = [], + stream: bool = False, + tools: list[Union[data_class.Tool, Manifest]] = None, + tool_outputs: list[data_class.ToolOutput] = None, + tool_choice: data_class.ToolChoice = None, + end_user_id: str = None, + action: data_class.Action = None, + **kwargs, + ) -> Message: r"""运行智能体应用 Args: @@ -283,7 +284,8 @@ def run(self, conversation_id: str, if query == "" and (tool_outputs is None or len(tool_outputs) == 0): raise ValueError( - "AppBuilderClient Run API: query and tool_outputs cannot both be empty") + "AppBuilderClient Run API: query and tool_outputs cannot both be empty" + ) req = data_class.AppBuilderClientRequest( app_id=self.app_id, @@ -313,18 +315,20 @@ def run(self, conversation_id: str, data = response.json() resp = data_class.AppBuilderClientResponse(**data) out = data_class.AppBuilderClientAnswer() - _transform(resp, out) + AppBuilderClient._transform(resp, out) return Message(content=out) - def run_with_handler(self, - conversation_id: str, - query: str = "", - file_ids: list = [], - tools: list[Union[data_class.Tool,Manifest]] = None, - stream: bool = False, - event_handler=None, - action=None, - **kwargs): + def run_with_handler( + self, + conversation_id: str, + query: str = "", + file_ids: list = [], + tools: list[Union[data_class.Tool, Manifest]] = None, + stream: bool = False, + event_handler=None, + action=None, + **kwargs, + ): r"""运行智能体应用,并通过事件处理器处理事件 Args: @@ -350,20 +354,22 @@ def run_with_handler(self, tools=tools, stream=stream, action=action, - **kwargs + **kwargs, ) return event_handler - def run_multiple_dialog_with_handler(self, - conversation_id: str, - queries: iter = None, - file_ids: iter = None, - tools: iter = None, - stream: bool = False, - event_handler=None, - actions: iter = None, - **kwargs): + def run_multiple_dialog_with_handler( + self, + conversation_id: str, + queries: iter = None, + file_ids: iter = None, + tools: iter = None, + stream: bool = False, + event_handler=None, + actions: iter = None, + **kwargs, + ): r"""运行智能体应用,并通过事件处理器处理事件 Args: @@ -415,7 +421,7 @@ def run_multiple_dialog_with_handler(self, event_handler.reset_state() @staticmethod - def _iterate_events(request_id, events) -> data_class.AppBuilderClientAnswer: + def _iterate_events(request_id, events): for event in events: try: data = event.data @@ -429,7 +435,7 @@ def _iterate_events(request_id, events) -> data_class.AppBuilderClientAnswer: ) inp = data_class.AppBuilderClientResponse(**data) out = data_class.AppBuilderClientAnswer() - _transform(inp, out) + AppBuilderClient._transform(inp, out) yield out @staticmethod @@ -441,6 +447,24 @@ def _check_console_response(request_id: str, data): service_err_message="message={}".format(data["message"]), ) + @staticmethod + def _transform( + inp: data_class.AppBuilderClientResponse, out: data_class.AppBuilderClientAnswer + ): + out.answer = inp.answer + for ev in inp.content: + event = data_class.Event( + code=ev.event_code, + message=ev.event_message, + status=ev.event_status, + event_type=ev.event_type, + content_type=ev.content_type, + detail=ev.outputs, + usage=ev.usage, + tool_calls=ev.tool_calls, + ) + out.events.append(event) + class AgentBuilder(AppBuilderClient): r"""AgentBuilder是继承自AppBuilderClient的一个子类,用于构建和管理智能体应用。 @@ -464,6 +488,7 @@ class AgentBuilder(AppBuilderClient): print(message.content) """ + @deprecated( reason="AgentBuilder is deprecated, please use AppBuilderClient instead", version="1.0.0", @@ -481,21 +506,3 @@ def __init__(self, app_id: str): """ super().__init__(app_id) - - -def _transform( - inp: data_class.AppBuilderClientResponse, out: data_class.AppBuilderClientAnswer -): - out.answer = inp.answer - for ev in inp.content: - event = data_class.Event( - code=ev.event_code, - message=ev.event_message, - status=ev.event_status, - event_type=ev.event_type, - content_type=ev.content_type, - detail=ev.outputs, - usage=ev.usage, - tool_calls=ev.tool_calls, - ) - out.events.append(event) diff --git a/python/core/console/appbuilder_client/async_appbuilder_client.py b/python/core/console/appbuilder_client/async_appbuilder_client.py new file mode 100644 index 000000000..f6f010f2d --- /dev/null +++ b/python/core/console/appbuilder_client/async_appbuilder_client.py @@ -0,0 +1,291 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +from typing import Union +from aiohttp import FormData +from appbuilder.core.component import Message, Component +from appbuilder.core.console.appbuilder_client import data_class, AppBuilderClient +from appbuilder.core.manifest.models import Manifest +from appbuilder.core._exception import AppBuilderServerException +from appbuilder.utils.sse_util import AsyncSSEClient + + +class AsyncAppBuilderClient(Component): + def __init__(self, app_id, **kwargs): + super().__init__(is_aysnc=True, **kwargs) + if (not isinstance(app_id, str)) or len(app_id) == 0: + raise ValueError( + "app_id must be a str, and length is bigger then zero," + "please go to official website which is 'https://cloud.baidu.com/product/AppBuilder'" + " to get a valid app_id after your application is published." + ) + self.app_id = app_id + + async def create_conversation(self) -> str: + r"""异步创建会话并返回会话ID + + 会话ID在服务端用于上下文管理、绑定会话文档等,如需开始新的会话,请创建并使用新的会话ID + + Args: + 无 + + Returns: + response (str): 唯一会话ID + + """ + headers = self.http_client.auth_header_v2() + headers["Content-Type"] = "application/json" + url = self.http_client.service_url_v2("/app/conversation") + response = await self.http_client.session.post( + url, headers=headers, json={"app_id": self.app_id}, timeout=None + ) + self.http_client.check_response_header(response) + data = await response.json() + resp = data_class.CreateConversationResponse(**data) + return resp.conversation_id + + async def run( + self, + conversation_id: str, + query: str = "", + file_ids: list = [], + stream: bool = False, + tools: list[Union[data_class.Tool, Manifest]] = None, + tool_outputs: list[data_class.ToolOutput] = None, + tool_choice: data_class.ToolChoice = None, + end_user_id: str = None, + action: data_class.Action = None, + **kwargs, + ) -> Message: + r"""异步运行智能体应用 + + Args: + query (str): query内容 + conversation_id (str): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话 + file_ids(list[str]): 文件ID列表 + stream (bool): 为True时,流式返回,需要将message.content.answer拼接起来才是完整的回答;为False时,对应非流式返回 + tools(list[Union[data_class.Tool,Manifest]]): 一个Tool或Manifest组成的列表,其中每个Tool(Manifest)对应一个工具的配置, 默认为None + tool_outputs(list[data_class.ToolOutput]): 工具输出列表,格式为list[ToolOutput], ToolOutputd内容为本地的工具执行结果,以自然语言/json dump str描述,默认为None + tool_choice(data_class.ToolChoice): 控制大模型使用组件的方式,默认为None + end_user_id (str): 用户ID,用于区分不同用户 + action(data_class.Action): 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + kwargs: 其他参数 + + Returns: + message (Message): 对话结果,一个Message对象,使用message.content获取内容。 + """ + + if len(conversation_id) == 0: + raise ValueError( + "conversation_id is empty, you can run self.create_conversation to get a conversation_id" + ) + + if query == "" and (tool_outputs is None or len(tool_outputs) == 0): + raise ValueError( + "AppBuilderClient Run API: query and tool_outputs cannot both be empty" + ) + + req = data_class.AppBuilderClientRequest( + app_id=self.app_id, + conversation_id=conversation_id, + query=query, + stream=True if stream else False, + file_ids=file_ids, + tools=tools, + tool_outputs=tool_outputs, + tool_choice=tool_choice, + end_user_id=end_user_id, + action=action, + ) + + headers = self.http_client.auth_header_v2() + headers["Content-Type"] = "application/json" + url = self.http_client.service_url_v2("/app/conversation/runs") + response = await self.http_client.session.post( + url, headers=headers, json=req.model_dump(), timeout=None + ) + self.http_client.check_response_header(response) + request_id = self.http_client.response_request_id(response) + if stream: + client = AsyncSSEClient(response) + return Message(content=self._iterate_events(request_id, client.events())) + else: + data = await response.json() + resp = data_class.AppBuilderClientResponse(**data) + out = data_class.AppBuilderClientAnswer() + AppBuilderClient._transform(resp, out) + return Message(content=out) + + async def upload_local_file(self, conversation_id, local_file_path: str) -> str: + r"""异步运行,上传文件并将文件与会话ID进行绑定,后续可使用该文件ID进行对话,目前仅支持上传xlsx、jsonl、pdf、png等文件格式 + + 该接口用于在对话中上传文件供大模型处理,文件的有效期为7天并且不超过对话的有效期。一次只能上传一个文件。 + + Args: + conversation_id (str) : 会话ID + local_file_path (str) : 本地文件路径 + + Returns: + response (str): 唯一文件ID + + """ + if len(conversation_id) == 0: + raise ValueError( + "conversation_id is empty, you can run self.create_conversation to get a conversation_id" + ) + + filepath = os.path.abspath(local_file_path) + if not os.path.exists(filepath): + raise FileNotFoundError(f"{filepath} does not exist") + multipart_form_data = FormData() + multipart_form_data.add_field( + name="file", + value=open(local_file_path, "rb"), + filename=os.path.basename(local_file_path), + ) + multipart_form_data.add_field(name="app_id", value=self.app_id) + multipart_form_data.add_field( + name="conversation_id", value=conversation_id) + + headers = self.http_client.auth_header_v2() + url = self.http_client.service_url_v2("/app/conversation/file/upload") + response = await self.http_client.session.post( + url, data=multipart_form_data, headers=headers + ) + self.http_client.check_response_header(response) + data = await response.json() + resp = data_class.FileUploadResponse(**data) + return resp.id + + async def run_with_handler( + self, + conversation_id: str, + query: str = "", + file_ids: list = [], + tools: list[Union[data_class.Tool, Manifest]] = None, + stream: bool = False, + event_handler=None, + action=None, + **kwargs, + ): + r"""异步运行智能体应用,并通过事件处理器处理事件 + + Args: + conversation_id (str): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话 + query (str): 查询字符串 + file_ids (list): 文件ID列表 + tools(list[Union[data_class.Tool,Manifest]], 可选): 一个Tool或Manifest组成的列表,其中每个Tool(Manifest)对应一个工具的配置, 默认为None + stream (bool): 是否流式响应 + event_handler (EventHandler): 事件处理器 + action(data_class.Action) 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + + kwargs: 其他参数 + + Returns: + EventHandler: 事件处理器 + """ + assert event_handler is not None, "event_handler is None" + await event_handler.init( + appbuilder_client=self, + conversation_id=conversation_id, + query=query, + file_ids=file_ids, + tools=tools, + stream=stream, + action=action, + **kwargs, + ) + + return event_handler + + async def run_multiple_dialog_with_handler( + self, + conversation_id: str, + queries: iter = None, + file_ids: iter = None, + tools: iter = None, + stream: bool = False, + event_handler=None, + actions: iter = None, + **kwargs, + ): + r"""运行智能体应用,并通过事件处理器处理事件 + + Args: + conversation_id (str): 唯一会话ID,如需开始新的会话,请使用self.create_conversation创建新的会话 + queries (iter): 查询字符串可迭代对象 + file_ids (iter): 文件ID列表 + tools(iter, 可选): 一个Tool或Manifest组成的列表,其中每个Tool(Manifest)对应一个工具的配置, 默认为None + stream (bool): 是否流式响应 + event_handler (EventHandler): 事件处理器 + actions(iter) 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + + kwargs: 其他参数 + Returns: + EventHandler: 事件处理器 + """ + assert event_handler is not None, "event_handler is None" + assert queries is not None, "queries is None" + + iter_queries = iter(queries) + iter_file_ids = iter(file_ids) if file_ids else iter([]) + iter_tools = iter(tools) if tools else iter([]) + iter_actions = iter(actions) if actions else iter([]) + + for index, query in enumerate(iter_queries): + file_id = next(iter_file_ids, None) + tool = next(iter_tools, None) + action = next(iter_actions, None) + + if index == 0: + await event_handler.init( + appbuilder_client=self, + conversation_id=conversation_id, + query=query, + file_ids=file_id, + tools=tool, + stream=stream, + action=action, + **kwargs, + ) + yield event_handler + else: + await event_handler.new_dialog( + query=query, + file_ids=file_id, + tools=tool, + stream=stream, + action=action, + ) + yield event_handler + await event_handler.reset_state() + + @staticmethod + async def _iterate_events(request_id, events): + async for event in events: + try: + data = event.data + if len(data) == 0: + data = event.raw + data = json.loads(data) + except json.JSONDecodeError as e: + raise AppBuilderServerException( + request_id=request_id, + message="json decoder failed {}".format(str(e)), + ) + inp = data_class.AppBuilderClientResponse(**data) + out = data_class.AppBuilderClientAnswer() + AppBuilderClient._transform(inp, out) + yield out diff --git a/python/core/console/appbuilder_client/async_event_handler.py b/python/core/console/appbuilder_client/async_event_handler.py new file mode 100644 index 000000000..afa93eef9 --- /dev/null +++ b/python/core/console/appbuilder_client/async_event_handler.py @@ -0,0 +1,459 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from appbuilder.utils.logger_util import logger +from appbuilder.core.console.appbuilder_client import data_class + + +class AppBuilderClientRunContext(object): + def __init__(self) -> None: + """ + 初始化方法。 + + Args: + 无参数。 + + Returns: + None + + """ + self.current_event = None + self.current_tool_calls = None + self.current_status = None + self.need_tool_submit = False + self.is_complete = False + self.current_thought = "" + + +class AsyncAppBuilderEventHandler(object): + def __init__(self): + pass + + async def init( + self, + appbuilder_client, + conversation_id, + query, + file_ids=None, + tools=None, + stream: bool = False, + event_handler=None, + action=None, + **kwargs + ): + """ + 初始化类实例并设置相关参数。 + + Args: + appbuilder_client (object): AppBuilder客户端实例对象。 + conversation_id (str): 对话ID。 + query (str): 用户输入的查询语句。 + file_ids (list, optional): 文件ID列表,默认为None。 + tools (list, optional): 工具列表,默认为None。 + stream (bool, optional): 是否使用流式处理,默认为False。 + event_handler (callable, optional): 事件处理函数,默认为None。 + action (object, optional): 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + **kwargs: 其他可选参数。 + + Returns: + None + + """ + self._appbuilder_client = appbuilder_client + self._conversation_id = conversation_id + self._query = query + self._file_ids = file_ids + self._tools = tools + self._stream = stream + self._event_handler = event_handler + self._kwargs = kwargs + self._is_complete = False + self._need_tool_call = False + self._last_tool_output = None + self._action = action + + self._iterator = ( + self.__run_process__() + if not self._stream + else self.__stream_run_process__() + ) + + async def __run_process__(self): + """ + 运行进程,并在每次执行后生成结果。 + + Args: + 无参数。 + + Returns: + Generator: 生成器,每次执行后返回结果。 + + """ + while not self._is_complete: + if not self._need_tool_call: + res = await self._run() + await self.__event_process__(res) + else: + res = await self._submit_tool_output() + await self.__event_process__(res) + yield res + if self._need_tool_call and self._is_complete: + await self.reset_state() + + async def __async_run_process__(self): + """ + 异步运行进程,并在每次执行后生成结果 + + Args: + 无参数 + + Returns: + Generator[Any, None, None]: 生成器,每次执行后返回结果 + """ + while not self._is_complete: + if not self._need_tool_call: + res = await self._run() + self.__event_process__(res) + else: + res = await self._submit_tool_output() + self.__event_process__(res) + yield res + if self._need_tool_call and self._is_complete: + self.reset_state() + + async def __event_process__(self, run_response): + """ + 处理事件响应。 + + Args: + run_response (RunResponse): 运行时响应对象。 + + Returns: + None + + Raises: + ValueError: 当解析事件时发生异常或工具输出为空时。 + """ + try: + event = run_response.content.events[-1] + except Exception as e: + raise ValueError(e) + + event_status = event.status + + if event.status == "success": + self._is_complete = True + elif event.status == "interrupt": + self._need_tool_call = True + + context_func_map = { + "preparing": self.preparing, + "running": self.running, + "error": self.error, + "done": self.done, + "interrupt": self.interrupt, + "success": self.success, + } + + run_context = AppBuilderClientRunContext() + await self._update_run_context(run_context, run_response.content) + await self.handle_event_type(run_context, run_response.content) + await self.handle_content_type(run_context, run_response.content) + if event_status in context_func_map: + func = context_func_map[event_status] + func_res = await func(run_context, run_response.content) + + if event_status == "interrupt": + assert isinstance(func_res, list) + if len(func_res) == 0: + raise ValueError("Tool output is empty") + else: + if not isinstance(func_res[0], data_class.ToolOutput): + try: + check_tool_output = data_class.ToolOutput(**func_res[0]) + except Exception as e: + logger.error( + "func interrupt's output should be list[ToolOutput] or list[dict(can be trans to ToolOutput)]" + ) + raise ValueError(e) + self._last_tool_output = func_res + else: + logger.warning( + "Unknown status: {}, response data: {}".format( + event_status, run_response + ) + ) + + async def __stream_run_process__(self): + """ + 异步流式运行处理函数 + + Args: + 无参数 + + Returns: + Generator[Any, None, None]: 返回处理结果的生成器 + """ + while not self._is_complete: + if not self._need_tool_call: + res = await self._run() + else: + res = await self._submit_tool_output() + async for msg in self.__stream_event_process__(res): + yield msg + + async def __stream_event_process__(self, run_response): + """ + 处理流事件,并调用对应的方法 + + Args: + run_response: 包含流事件信息的响应对象 + + Returns: + None + + Raises: + ValueError: 当处理事件时发生异常或中断时工具输出为空时 + """ + async for msg in run_response.content: + if len(msg.events) == 0: + continue + try: + event = msg.events[-1] + except Exception as e: + raise ValueError(e) + + event_status = event.status + + if event.status == "success": + self._is_complete = True + elif event.status == "interrupt": + self._need_tool_call = True + + context_func_map = { + "preparing": self.preparing, + "running": self.running, + "error": self.error, + "done": self.done, + "interrupt": self.interrupt, + "success": self.success, + } + + run_context = AppBuilderClientRunContext() + await self._update_run_context(run_context, msg) + await self.handle_event_type(run_context, msg) + await self.handle_content_type(run_context, msg) + if event_status in context_func_map: + func = context_func_map[event_status] + func_res = await func(run_context, msg) + + if event_status == "interrupt": + assert isinstance(func_res, list) + if len(func_res) == 0: + raise ValueError("Tool output is empty") + else: + if not isinstance(func_res[0], data_class.ToolOutput): + try: + check_tool_output = data_class.ToolOutput(**func_res[0]) + except Exception as e: + logger.info( + "func interrupt's output should be list[ToolOutput] or list[dict(can be trans to ToolOutput)]" + ) + raise ValueError(e) + self._last_tool_output = func_res + else: + logger.warning( + "Unknown status: {}, response data: {}".format( + event_status, run_response + ) + ) + + yield msg + + async def _update_run_context(self, run_context, run_response): + """ + 更新运行上下文。 + + Args: + run_context (dict): 运行上下文字典。 + run_response (object): 运行响应对象。 + + Returns: + None + + """ + run_context.current_event = run_response.events[-1] + run_context.current_tool_calls = run_context.current_event.tool_calls + run_context.current_status = run_context.current_event.status + run_context.need_tool_submit = run_context.current_status == "interrupt" + run_context.is_complete = run_context.current_status == "success" + try: + run_context.current_thought = ( + run_context.current_event.detail.get("text", {}) + .get("function_call", {}) + .get("thought", "") + ) + except Exception as e: + pass + + async def _run(self): + res = await self._appbuilder_client.run( + conversation_id=self._conversation_id, + query=self._query, + file_ids=self._file_ids, + stream=self._stream, + tools=self._tools, + action=self._action, + ) + return res + + async def _submit_tool_output(self): + assert self._last_tool_output is not None + res = await self._appbuilder_client.run( + conversation_id=self._conversation_id, + file_ids=self._file_ids, + stream=self._stream, + tool_outputs=self._last_tool_output, + ) + return res + + async def __anext__(self): + return await self._iterator.__anext__() + + async def __aiter__(self): + async for item in self._iterator: + yield item + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if exc_type is not None: + raise exc_val + + return + + async def reset_state(self): + """ + 重置该对象的状态,将所有实例变量设置为默认值。 + + Args: + 无 + + Returns: + 无 + + """ + self._appbuilder_client = None + self._conversation_id = None + self._query = None + self._file_ids = None + self._tools = None + self._stream = False + self._event_handler = None + self._kwargs = None + self._last_tool_output = None + self._is_complete = False + self._need_tool_call = False + self._iterator = None + + async def new_dialog( + self, + query=None, + file_ids=None, + tools=None, + action=None, + stream: bool = None, + event_handler=None, + **kwargs + ): + """ + 重置handler部分参数,用于复用该handler进行多轮对话。 + + Args: + query (str): 用户输入的查询语句。 + file_ids (list, optional): 文件ID列表,默认为None。 + tools (list, optional): 工具列表,默认为None。 + stream (bool, optional): 是否使用流式处理,默认为False。 + action (object, optional): 对话时要进行的特殊操作。如回复工作流agent中“信息收集节点“的消息。 + event_handler (callable, optional): 事件处理函数,默认为None。 + **kwargs: 其他可选参数。 + + Returns: + None + + """ + self._query = query or self._query + self._stream = stream or self._stream + + self._file_ids = file_ids + self._tools = tools + self._event_handler = event_handler + self._kwargs = kwargs + self._action = action + + # 重置部分状态 + self._is_complete = False + self._need_tool_call = False + self._last_tool_output = None + self._iterator = ( + self.__run_process__() + if not self._stream + else self.__stream_run_process__() + ) + + async def until_done(self): + """ + 迭代并遍历内部迭代器中的所有元素,直到迭代器耗尽。 + + Args: + 无参数。 + + Returns: + 无返回值。 + + """ + async for _ in self._iterator: + pass + + async def handle_content_type(self, run_context, run_response): + # 用户可重载该方法,用于处理不同类型的content_type + pass + + async def handle_event_type(self, run_context, run_response): + # 用户可重载该方法,用于处理不同类型的event_type + pass + + async def interrupt(self, run_context, run_response): + # 用户可重载该方法,当event_status为interrupt时,会调用该方法 + pass + + async def preparing(self, run_context, run_response): + # 用户可重载该方法,当event_status为preparing时,会调用该方法 + pass + + async def running(self, run_context, run_response): + # 用户可重载该方法,当event_status为running时,会调用该方法 + pass + + async def error(self, run_context, run_response): + # 用户可重载该方法,当event_status为error时,会调用该方法 + pass + + async def done(self, run_context, run_response): + # 用户可重载该方法,当event_status为done时,会调用该方法 + pass + + async def success(self, run_context, run_response): + # 用户可重载该方法,当event_status为success时,会调用该方法 + pass diff --git a/python/tests/test_async_appbuilder_client.py b/python/tests/test_async_appbuilder_client.py new file mode 100644 index 000000000..407a7a2d1 --- /dev/null +++ b/python/tests/test_async_appbuilder_client.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import asyncio +import appbuilder + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientAsync(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "fb64d96b-f828-4385-ba1d-835298d635a9" + + def test_async_run_stream(self): + appbuilder.logger.setLoglevel("ERROR") + + async def agent_run(client, conversation_id, text): + ans = await client.run(conversation_id, text, stream=True) + async for data in ans.content: + print(data) + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + task1 = asyncio.create_task( + agent_run(client, conversation_id, "最早的邮展")) + task2 = asyncio.create_task( + agent_run(client, conversation_id, "最早的漫展")) + await asyncio.gather(task1, task2) + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + def test_async_run(self): + appbuilder.logger.setLoglevel("ERROR") + + async def agent_run(client, conversation_id, text): + ans = await client.run(conversation_id, text, stream=False) + print(ans.content.answer) + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + await client.upload_local_file(conversation_id, "./data/qa_appbuilder_client_demo.pdf") + task1 = asyncio.create_task( + agent_run(client, conversation_id, "最早的邮展")) + task2 = asyncio.create_task( + agent_run(client, conversation_id, "最早的漫展")) + await asyncio.gather(task1, task2) + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_async_appbuilder_client_chatflow.py b/python/tests/test_async_appbuilder_client_chatflow.py new file mode 100644 index 000000000..d6ebe6896 --- /dev/null +++ b/python/tests/test_async_appbuilder_client_chatflow.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import asyncio +import unittest +import appbuilder +from appbuilder.core.console.appbuilder_client.async_event_handler import ( + AsyncAppBuilderEventHandler, +) + + +class MyEventHandler(AsyncAppBuilderEventHandler): + def __init__(self): + super().__init__() + self.interrupt_ids = [] + + async def handle_content_type(self, run_context, run_response): + interrupt_event_id = None + event = run_response.events[-1] + if event.content_type == "chatflow_interrupt": + interrupt_event_id = event.detail.get("interrupt_event_id") + if interrupt_event_id is not None: + self.interrupt_ids.append(interrupt_event_id) + + def _create_action(self): + if len(self.interrupt_ids) == 0: + return None + event_id = self.interrupt_ids.pop() + return { + "action_type": "resume", + "parameters": {"interrupt_event": {"id": event_id, "type": "chat"}}, + } + + async def run(self, query=None): + await super().new_dialog( + query=query, + action=self._create_action(), + ) + + def gen_action(self): + while True: + yield self._create_action() + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientChatflow(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "4403205e-fb83-4fac-96d8-943bdb63796f" + + def test_chatflow(self): + appbuilder.logger.setLoglevel("DEBUG") + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + event_handler = MyEventHandler() + await event_handler.init( + appbuilder_client=client, + conversation_id=conversation_id, + stream=False, + query="查天气", + ) + async for data in event_handler: + pass + await event_handler.run( + query="查航班", + ) + async for data in event_handler: + pass + await event_handler.run( + query="CA1234", + ) + async for data in event_handler: + pass + await event_handler.run( + query="北京的", + ) + async for data in event_handler: + pass + + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + def test_chatflow_stream(self): + appbuilder.logger.setLoglevel("DEBUG") + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + event_handler = MyEventHandler() + await event_handler.init( + appbuilder_client=client, + conversation_id=conversation_id, + stream=True, + query="查天气", + ) + async for data in event_handler: + pass + await event_handler.run( + query="查航班", + ) + async for data in event_handler: + pass + await event_handler.run( + query="CA1234", + ) + async for data in event_handler: + pass + await event_handler.run( + query="北京的", + ) + async for data in event_handler: + pass + + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + def test_chatflow_stream(self): + appbuilder.logger.setLoglevel("DEBUG") + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + event_handler = MyEventHandler() + await event_handler.init( + appbuilder_client=client, + conversation_id=conversation_id, + stream=True, + query="查天气", + ) + async for data in event_handler: + pass + await event_handler.run( + query="查航班", + ) + async for data in event_handler: + pass + await event_handler.run( + query="CA1234", + ) + async for data in event_handler: + pass + await event_handler.run( + query="北京的", + ) + async for data in event_handler: + pass + + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + def test_chatflow_multiple_dialog(self): + appbuilder.logger.setLoglevel("DEBUG") + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + queries = ["查天气", "查航班", "CA1234", "北京的"] + event_handler = MyEventHandler() + event_handler = client.run_multiple_dialog_with_handler( + conversation_id=conversation_id, + queries=queries, + event_handler=event_handler, + stream=False, + actions=event_handler.gen_action(), + ) + async for data in event_handler: + async for answer in data: + print(answer) + + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_async_appbuilder_client_follow_up_query.py b/python/tests/test_async_appbuilder_client_follow_up_query.py new file mode 100644 index 000000000..0c7f54ae3 --- /dev/null +++ b/python/tests/test_async_appbuilder_client_follow_up_query.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import asyncio +import appbuilder +from appbuilder.core.console.appbuilder_client.async_event_handler import ( + AsyncAppBuilderEventHandler, +) + + +class MyEventHandler(AsyncAppBuilderEventHandler): + def __init__(self): + super().__init__() + self.follow_up_queries = [] + + async def handle_content_type(self, run_context, run_response): + event = run_response.events[-1] + if event.content_type == "json" and event.event_type == "FollowUpQuery": + follow_up_queries = event.detail.get("json").get("follow_up_querys") + self.follow_up_queries.extend(follow_up_queries) + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAppBuilderClientAsync(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "fb64d96b-f828-4385-ba1d-835298d635a9" + + def test_async_run_stream(self): + appbuilder.logger.setLoglevel("ERROR") + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + event_handler = MyEventHandler() + with await client.run_with_handler( + conversation_id = conversation_id, + query = "你能做什么", + stream=True, + event_handler=event_handler, + ) as run: + await run.until_done() + + print(event_handler.follow_up_queries) + assert len(event_handler.follow_up_queries) > 0 + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_async_appbuilder_client_toolcall.py b/python/tests/test_async_appbuilder_client_toolcall.py new file mode 100644 index 000000000..00322bf3b --- /dev/null +++ b/python/tests/test_async_appbuilder_client_toolcall.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import appbuilder +import asyncio +import os +from appbuilder.core.console.appbuilder_client.async_event_handler import ( + AsyncAppBuilderEventHandler, +) + + +class MyEventHandler(AsyncAppBuilderEventHandler): + def get_current_weather(self, location=None, unit="摄氏度"): + return "{} 的温度是 {} {}".format(location, 20, unit) + + async def interrupt(self, run_context, run_response): + thought = run_context.current_thought + # 绿色打印 + print("\033[1;32m", "-> Agent 中间思考: ", thought, "\033[0m") + + tool_output = [] + for tool_call in run_context.current_tool_calls: + tool_call_id = tool_call.id + tool_res = self.get_current_weather(**tool_call.function.arguments) + # 蓝色打印 + print("\033[1;34m", "-> 本地ToolCall结果: ", tool_res, "\033[0m\n") + tool_output.append( + {"tool_call_id": tool_call_id, "output": tool_res}) + return tool_output + + async def success(self, run_context, run_response): + print("\n\033[1;31m", "-> Agent 非流式回答: ", + run_response.answer, "\033[0m") + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "") +class TestAgentRuntime(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + 无参数,默认值为空。 + + Returns: + 无返回值,方法中执行了环境变量的赋值操作。 + """ + self.app_id = "b2a972c5-e082-46e5-b313-acbf51792422" + + def test_appbuilder_client_tool_call(self): + # 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败 + """ + 如果app_id为空,则跳过单测执行, 避免单测因配置无效而失败 + + Args: + self (unittest.TestCase): unittest的TestCase对象 + + Raises: + None: 如果app_id不为空,则不会引发任何异常 + unittest.SkipTest (optional): 如果app_id为空,则跳过单测执行 + """ + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "仅支持中国城市的天气查询,参数location为中国城市名称,其他国家城市不支持天气查询", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "城市名,举例:北京", + }, + "unit": {"type": "string", "enum": ["摄氏度", "华氏度"]}, + }, + "required": ["location", "unit"], + }, + }, + } + ] + + appbuilder.logger.setLoglevel("ERROR") + + async def agent_run(client, conversation_id, query): + with await client.run_with_handler( + conversation_id=conversation_id, + query=query, + tools=tools, + event_handler=MyEventHandler(), + ) as run: + await run.until_done() + + async def agent_handle(): + client = appbuilder.AsyncAppBuilderClient(self.app_id) + conversation_id = await client.create_conversation() + task1 = asyncio.create_task( + agent_run(client, conversation_id, "北京的天气怎么样")) + task2 = asyncio.create_task( + agent_run(client, conversation_id, "上海的天气怎么样")) + await asyncio.gather(task1, task2) + + await client.http_client.session.close() + + loop = asyncio.get_event_loop() + loop.run_until_complete(agent_handle()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_core_client.py b/python/tests/test_core_client.py index 1cd654408..b1b8f52ad 100644 --- a/python/tests/test_core_client.py +++ b/python/tests/test_core_client.py @@ -15,100 +15,142 @@ import unittest import json -from appbuilder.core._client import HTTPClient -from appbuilder.core._exception import * +from appbuilder.core._client import HTTPClient, AsyncHTTPClient +from appbuilder.core._exception import * # 创建一个response类,模拟requests.Response + + class Response: def __init__(self, status_code, headers, text): self.status_code = status_code self.headers = headers self.text = text - + + def json(self): + return json.loads(self.text) + + +class AsyncResponse: + def __init__(self, status_code, headers, text): + self.status = status_code + self.headers = headers + self.text = text + def json(self): - return json.loads(self.text) + return json.loads(self.text) + class TestCoreClient(unittest.TestCase): def setUp(self): # 保存原始环境变量 self.original_appbuilder_token = os.getenv('APPBUILDER_TOKEN') self.original_gateway_url = os.getenv('GATEWAY_URL') - + def tearDown(self): # 恢复环境变量 if self.original_appbuilder_token is None: os.unsetenv('APPBUILDER_TOKEN') else: os.environ['APPBUILDER_TOKEN'] = self.original_appbuilder_token - + if self.original_gateway_url is None: os.unsetenv('GATEWAY_URL') else: os.environ['GATEWAY_URL'] = self.original_gateway_url - + def test_core_client_init_non_APPBUILDER_TOKEN(self): os.environ['APPBUILDER_TOKEN'] = '' with self.assertRaises(ValueError): HTTPClient() - + def test_core_client_init_non_GATEWAY_URL(self): - os.environ['GATEWAY_URL'] = 'test' - hp=HTTPClient() + os.environ['GATEWAY_URL'] = 'test' + hp = HTTPClient() assert hp.gateway.startswith('https://') - + def test_core_client_check_response_header(self): # 测试各种response报错 response = Response( status_code=400, - headers={'Content-Type': 'application/json'} , + headers={'Content-Type': 'application/json'}, text='{"code": 0, "message": "success"}' - ) + ) with self.assertRaises(BadRequestException): HTTPClient.check_response_header(response) - + response.status_code = 403 with self.assertRaises(ForbiddenException): HTTPClient.check_response_header(response) - + response.status_code = 404 with self.assertRaises(NotFoundException): HTTPClient.check_response_header(response) - + response.status_code = 428 with self.assertRaises(PreconditionFailedException): HTTPClient.check_response_header(response) - + response.status_code = 500 with self.assertRaises(InternalServerErrorException): HTTPClient.check_response_header(response) - + response.status_code = 201 with self.assertRaises(BaseRPCException): HTTPClient.check_response_header(response) - + + def test_core_client_check_async_response_header(self): + # 测试各种response报错 + response = AsyncResponse( + status_code=400, + headers={'Content-Type': 'application/json'}, + text='{"code": 0, "message": "success"}' + ) + with self.assertRaises(BadRequestException): + AsyncHTTPClient.check_response_header(response) + + response.status = 403 + with self.assertRaises(ForbiddenException): + AsyncHTTPClient.check_response_header(response) + + response.status = 404 + with self.assertRaises(NotFoundException): + AsyncHTTPClient.check_response_header(response) + + response.status = 428 + with self.assertRaises(PreconditionFailedException): + AsyncHTTPClient.check_response_header(response) + + response.status = 500 + with self.assertRaises(InternalServerErrorException): + AsyncHTTPClient.check_response_header(response) + + response.status = 201 + with self.assertRaises(BaseRPCException): + AsyncHTTPClient.check_response_header(response) + def test_core_client_check_response_json(self): - data={ + data = { 'code': 0, 'message': 'test', - 'requestId':'test' - } + 'requestId': 'test' + } with self.assertRaises(AppBuilderServerException): HTTPClient.check_response_json(data) - + def test_core_check_console_response(self): response = Response( status_code=400, - headers={'Content-Type': 'application/json'} , + headers={'Content-Type': 'application/json'}, text=json.dumps({ 'code': 1, 'message': 'test', - 'requestId':'test' + 'requestId': 'test' }) - ) + ) with self.assertRaises(AppBuilderServerException): HTTPClient.check_console_response(response) - - + + if __name__ == '__main__': unittest.main() - \ No newline at end of file diff --git a/python/tests/test_core_session.py b/python/tests/test_core_session.py new file mode 100644 index 000000000..c9d671b0e --- /dev/null +++ b/python/tests/test_core_session.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import appbuilder +import asyncio +import aiohttp +from unittest.mock import patch, MagicMock +from appbuilder.core._session import AsyncInnerSession + +class TestCoreSession(unittest.TestCase): + @patch("aiohttp.ClientSession.put") + def test_async_session_get(self, mock_put): + async def demo(): + return {"status": 200} + + async def async_magic(): + pass + + async def get_demo(): + mock_put.return_value.__aenter__.return_value.json = await demo() + MagicMock.__await__ = lambda x: async_magic().__await__() + session = AsyncInnerSession() + await session.get("http://www.baidu.com") + await session.put("https://example.com") + + loop = asyncio.get_event_loop() + loop.run_until_complete(get_demo()) + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_utils.py b/python/tests/test_utils.py index 8bce87418..b2a4b0011 100644 --- a/python/tests/test_utils.py +++ b/python/tests/test_utils.py @@ -13,9 +13,10 @@ # limitations under the License. import os import unittest +import asyncio from unittest.mock import MagicMock -from appbuilder.utils.sse_util import SSEClient,Event +from appbuilder.utils.sse_util import SSEClient,AsyncSSEClient, Event from appbuilder.utils.model_util import RemoteModel,Models from appbuilder.utils.logger_util import LoggerWithLoggerId,_setup_logging,logger from threading import current_thread @@ -55,6 +56,33 @@ def test_sse_util_SSEClient(self): # test_close sse_client.close() + + def test_sse_util_AsyncSSEClient(self): + async def mock_client(): + mock_event_source = MagicMock() + mock_event_source.__iter__.return_value = iter([ + b'data: Test event 1\n\n', + b'data: Last incomplete event' + ]) + sse_client = AsyncSSEClient(mock_event_source) + event_generator = sse_client._read() + async for data in event_generator: + pass + + # test_events + mock_event_source.__aiter__.return_value = iter([ + b': Test event 1\n\n', + b'test: Test event 2\n\n', + b'data:Testevent3\n\n', + b'data\n\n', + b'event:Testevent5\n\n', + ]) + sse_client = AsyncSSEClient(mock_event_source) + async for event in sse_client.events(): + pass + + loop = asyncio.get_event_loop() + loop.run_until_complete(mock_client()) def test_sse_util_SSEClient_DEBUG(self): logger.setLoglevel("DEBUG") diff --git a/python/utils/sse_util.py b/python/utils/sse_util.py index 027923c8e..a1984a397 100644 --- a/python/utils/sse_util.py +++ b/python/utils/sse_util.py @@ -16,19 +16,21 @@ """ from appbuilder.utils.logger_util import logger import logging +import aiohttp + class SSEClient: """ 一个简易的SSE Client,用于接收服务端发送的SSE事件。 """ - def __init__(self, event_source, char_enc='utf-8'): + def __init__(self, event_source, char_enc="utf-8"): """ 通过现有的事件源初始化 SSE 客户端。 事件源应为二进制流,并具有 close() 方法。 这通常是实现 io.BinaryIOBase 的东西,比如 httplib 或 urllib3HTTPResponse 对象。 """ - logger.info(f'Initialized SSE client from event source {event_source}') + logger.info(f"Initialized SSE client from event source {event_source}") self._event_source = event_source self._char_enc = char_enc @@ -38,23 +40,23 @@ def _read(self): 不幸的是,有些服务器可能会决定在响应中将事件分解为多个HTTP块。 因此,有必要正确地将连续的响应块缝合在一起,并找到SSE分隔符(空的新行),以生成完整、正确的事件块。 """ - data = b'' + data = b"" for chunk in self._event_source: for line in chunk.splitlines(True): data += line - if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): yield data - data = b'' + data = b"" if data: yield data def events(self): """ 从给定的输入流中读取 Server-Side-Event (SSE) 数据,并生成解析后的 Event 对象。 - + Args: 无 - + Returns: generator: 解析后的 Event 对象的生成器。 """ @@ -66,34 +68,36 @@ def events(self): line = line.decode(self._char_enc) # Lines starting with a separator are comments and are to be # ignored. - if not line.strip() or line.startswith(':'): + if not line.strip() or line.startswith(":"): continue logger.debug(f"raw line: {line}") - data = line.split(':', 1) + data = line.split(":", 1) field = data[0] # Ignore unknown fields. if field not in event.__dict__: event.raw += line - logger.info(f'Saw invalid field {field} while parsing Server Side Event') + logger.info( + f"Saw invalid field {field} while parsing Server Side Event" + ) continue if len(data) > 1: # From the spec: # "If value starts with a single U+0020 SPACE character, # remove it from value." - if data[1].startswith(' '): + if data[1].startswith(" "): value = data[1][1:] else: value = data[1] else: # If no value is present after the separator, # assume an empty value. - value = '' + value = "" # The data field may come over multiple lines and their values # are concatenated with each other. - if field == 'data': - event.__dict__[field] += value + '\n' - event.raw += value + '\n' + if field == "data": + event.__dict__[field] += value + "\n" + event.raw += value + "\n" else: event.__dict__[field] = value event.raw += value @@ -107,15 +111,15 @@ def events(self): continue else: # If the data field ends with a newline, remove it. - if event.data.endswith('\n'): + if event.data.endswith("\n"): event.data = event.data[0:-1] # Empty event names default to 'message' - event.event = event.event or 'message' + event.event = event.event or "message" # Dispatch the event if logger.getEffectiveLevel() == logging.DEBUG: - logger.debug(f'Dispatching {event.debug_str}...') + logger.debug(f"Dispatching {event.debug_str}...") else: - logger.info(f'Dispatching {event}...') + logger.info(f"Dispatching {event}...") yield event def close(self): @@ -125,11 +129,96 @@ def close(self): self._event_source.close() +class AsyncSSEClient: + """ + 一个简易的SSE Client,用于接收服务端发送的SSE事件。 + """ + def __init__(self, response, char_enc='utf-8'): + """ + 通过现有的事件源response初始化 SSE 客户端。 + response应为aiohttp.ClientResponse实例 + """ + self._response = response + self._char_enc = char_enc + + async def _read(self): + """ + 读取传入的事件源流并生成事件块。 + """ + data = b'' + async for chunk in self._response.content.iter_any(): + for line in chunk.splitlines(True): + data += line + if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + yield data + data = b'' + if data: + yield data + + async def events(self): + """ + 从给定的输入流中读取 Server-Side-Event (SSE) 数据,并生成解析后的 Event 对象。 + Returns: + generator: 解析后的 Event 对象的生成器。 + """ + async for chunk in self._read(): + event = Event() + # Split before decoding so splitlines() only uses \r and \n + for line in chunk.splitlines(): + # Decode the line. + line = line.decode(self._char_enc) + # Lines starting with a separator are comments and are to be ignored. + if not line.strip() or line.startswith(':'): + continue + + data = line.split(':', 1) + field = data[0] + # Ignore unknown fields. + if field not in event.__dict__: + event.raw += line + continue + + if len(data) > 1: + # From the spec: + # "If value starts with a single U+0020 SPACE character, + # remove it from value." + if data[1].startswith(' '): + value = data[1][1:] + else: + value = data[1] + else: + # If no value is present after the separator, + # assume an empty value. + value = '' + + # The data field may come over multiple lines and their values are concatenated with each other. + if field == 'data': + event.__dict__[field] += value + '\n' + event.raw += value + '\n' + else: + event.__dict__[field] = value + event.raw += value + + # Events with no data are not dispatched. + if not event.data: + continue + + # If the data field ends with a newline, remove it. + if event.data.endswith('\n'): + event.data = event.data[0:-1] + + # Empty event names default to 'message' + event.event = event.event or 'message' + + yield event + + class Event(object): """ 事件流中的事件。 """ - def __init__(self, id=None, event='message', data='', retry=None): + + def __init__(self, id=None, event="message", data="", retry=None): self.id = id self.event = event self.data = data @@ -137,30 +226,30 @@ def __init__(self, id=None, event='message', data='', retry=None): self.raw = "" def __str__(self): - s = f'{self.event} event' + s = f"{self.event} event" if self.id: - s += f' #{self.id}' + s += f" #{self.id}" if self.data: - s += f', {len(self.data)} byte' + s += f", {len(self.data)} byte" else: - s += ', no data' + s += ", no data" if self.retry: - s += f', retry in {self.retry} ms' + s += f", retry in {self.retry} ms" return s @property def debug_str(self): - s = f'{self.event} event' + s = f"{self.event} event" if self.id: - s += f' #{self.id}' + s += f" #{self.id}" if self.data: - s += f', {len(self.data)} byte, DATA<<{self.data}>>' + s += f", {len(self.data)} byte, DATA<<{self.data}>>" else: - s += ', no data' + s += ", no data" if self.raw: - s += f', RAW<<{self.raw}>>' + s += f", RAW<<{self.raw}>>" else: - s += ', no raw' + s += ", no raw" if self.retry: - s += f', retry in {self.retry} ms' + s += f", retry in {self.retry} ms" return s