From 1fe77040f03dc41675f1bb3787537a1fcd64871f Mon Sep 17 00:00:00 2001 From: fangweimin Date: Fri, 29 Dec 2023 10:37:38 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=88=97=E8=A1=A8=E8=8E=B7=E5=8F=96=E7=9A=84SDK?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- appbuilder/__init__.py | 3 + appbuilder/core/components/asr/component.py | 7 +- appbuilder/core/utils.py | 41 +++- appbuilder/tests/test_get_model_list.py | 57 +++++ appbuilder/utils/model_util.py | 248 ++++++++++++++++++++ 5 files changed, 352 insertions(+), 4 deletions(-) create mode 100644 appbuilder/tests/test_get_model_list.py create mode 100644 appbuilder/utils/model_util.py diff --git a/appbuilder/__init__.py b/appbuilder/__init__.py index 146f4b289..036146695 100644 --- a/appbuilder/__init__.py +++ b/appbuilder/__init__.py @@ -73,6 +73,9 @@ def check_version(self): from appbuilder.utils.logger_util import logger +from appbuilder.core.utils import get_model_list + + from .core._exception import ( BadRequestException, ForbiddenException, diff --git a/appbuilder/core/components/asr/component.py b/appbuilder/core/components/asr/component.py index da6eca9aa..cec07d27a 100644 --- a/appbuilder/core/components/asr/component.py +++ b/appbuilder/core/components/asr/component.py @@ -63,12 +63,12 @@ def run(self, message: Message, audio_format: str = "pcm", rate: int = 16000, request.cuid = str(uuid.uuid4()) request.dev_pid = "80001" request.speech = inp.raw_audio - response = self._recognize(request) + response = self._recognize(request, timeout, retry) out = ASROutMsg(result=list(response.result)) return Message(content=dict(out)) def _recognize(self, request: ShortSpeechRecognitionRequest, timeout: float = None, - retry: int = 0) -> ShortSpeechRecognitionResponse: + retry: int = 0) -> ShortSpeechRecognitionResponse: """ 使用给定的输入并返回语音识别的结果。 @@ -89,7 +89,8 @@ def _recognize(self, request: ShortSpeechRecognitionRequest, timeout: float = No } if retry != self.http_client.retry.total: self.http_client.retry.total = retry - response = self.http_client.session.post(self.http_client.service_url("/v1/bce/aip_speech/asrpro"), params=params, headers=headers, data=request.speech, timeout=timeout) + response = self.http_client.session.post(self.http_client.service_url("/v1/bce/aip_speech/asrpro"), + params=params, headers=headers, data=request.speech, timeout=timeout) self.http_client.check_response_header(response) data = response.json() self.http_client.check_response_json(data) diff --git a/appbuilder/core/utils.py b/appbuilder/core/utils.py index bce92fc9c..6cd70e6f0 100644 --- a/appbuilder/core/utils.py +++ b/appbuilder/core/utils.py @@ -11,10 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List - +from appbuilder.utils.model_util import GetModelListRequest, Models, map_model_name def utils_get_user_agent(): return 'appbuilder-sdk-python/{}'.format("__version__") + +def get_model_list(secret_key: str = "", apiTypefilter: List[str] = [], is_available: bool = False): + """ + 返回用户的模型列表。 + + 参数: + secret_key(str,可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "")。 + apiTypefilter(List[str], 可选): 根据apiType过滤,["chat", "completions", "embeddings", "text2image"],不填包括所有的。 + is_available(bool, 可选): 返回可用模型列表, 默认返回所有模型。 + + 返回: + List[str]: 模型列表。 + """ + request = GetModelListRequest() + request.apiTypefilter = apiTypefilter + model = Models(secret_key=secret_key) + response = model.list(request) + models = [] + if is_available: + for common_model in response.result.common: + if common_model.chargeStatus == "OPENED": + mapped_name = map_model_name(common_model.name) + models.append(mapped_name) + + for custom_model in response.result.custom: + if custom_model.chargeStatus == "OPENED": + mapped_name = map_model_name(custom_model.name) + models.append(mapped_name) + return models + else: + for common_model in response.result.common: + mapped_name = map_model_name(common_model.name) + models.append(mapped_name) + + for custom_model in response.result.custom: + mapped_name = map_model_name(custom_model.name) + models.append(mapped_name) + return models diff --git a/appbuilder/tests/test_get_model_list.py b/appbuilder/tests/test_get_model_list.py new file mode 100644 index 000000000..993eee6de --- /dev/null +++ b/appbuilder/tests/test_get_model_list.py @@ -0,0 +1,57 @@ +import unittest + +import appbuilder +from appbuilder.utils.model_util import GetModelListRequest, Models, GetModelListResponse + + +class TestModels(unittest.TestCase): + def setUp(self): + """ + 设置环境变量。 + + Args: + None + + Returns: + None. + """ + self.model = Models() + + def test_list(self): + """ + _list方法单测 + + Args: + None + + Returns: + None + + """ + + request = GetModelListRequest() + response = self.model.list(request) + self.assertIsNotNone(response) + self.assertIsInstance(response, GetModelListResponse) + + def test_check_service_error(self): + """ + check_service_error方法单测 + + Args: + None + + Returns: + None + + """ + data = {'error_msg': 'Error', 'error_code': 1} + request_id = "request_id" + with self.assertRaises(appbuilder.AppBuilderServerException): + self.model._check_service_error(request_id, data) + data = {'error_msg': 'No Error', 'error_code': 0} + self.assertIsNone(self.model._check_service_error(request_id, data)) + + +if __name__ == '__main__': + unittest.main() diff --git a/appbuilder/utils/model_util.py b/appbuilder/utils/model_util.py new file mode 100644 index 000000000..c49bba29b --- /dev/null +++ b/appbuilder/utils/model_util.py @@ -0,0 +1,248 @@ +import json +import proto +from typing import Optional, MutableSequence + +import appbuilder +from appbuilder.core._client import HTTPClient + +r"""模型名称到简称的映射. +""" +model_name_mapping = { + "ERNIE-Bot 4.0": "eb-4", + "ERNIE-Bot-8K": "eb-8k", + "ERNIE-Bot": "eb", + "ERNIE-Bot-turbo": "eb-turbo", + "EB-turbo-AppBuilder专用版": "eb-turbo-appbuilder", +} + + +class GetModelListRequest(proto.Message): + r"""获取模型列表请求体 + 参数: + apiTypefilter(str): + 根据apiType过滤,["chat", "completions", "embeddings", "text2image"],不填包括所有的。 + """ + apiTypefilter: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=1 + ) + + +class GetModelListResponse(proto.Message): + r"""获取模型列表返回体 + 参数: + request_id(str): + 网关层的请求ID. + log_id(str): + 请求ID。 + success(bool): + 是否成功的返回。 + error_code(int): + 错误码。 + error_msg(str): + 错误信息。 + result(ModelListResult): + 模型列表。 + """ + request_id: str = proto.Field( + proto.STRING, + number=1, + ) + + log_id: str = proto.Field( + proto.STRING, + number=2, + ) + + success: bool = proto.Field( + proto.BOOL, + number=3, + ) + + error_code: int = proto.Field( + proto.INT32, + number=4, + ) + + error_msg: str = proto.Field( + proto.STRING, + number=5, + ) + result: "ModelListResult" = proto.Field( + proto.MESSAGE, + number=6, + message="ModelListResult", + ) + + +class ModelListResult(proto.Message): + r"""模型列表 + 参数: + common(ModelData): + 预置服务模型信息。 + custom(ModelData): + 自定义服务模型信息。 + """ + common: MutableSequence["ModelData"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ModelData", + ) + + custom: MutableSequence["ModelData"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="ModelData", + ) + + +class ModelData(proto.Message): + r"""模型基本信息 + 参数: + name(str): + 服务名称。 + url(int): + 服务endpoint。 + apiType(str): + 服务类型:chat、completions、embeddings、text2image。 + chargeStatus(int): + 付费状态。 + versionList(int): + 服务版本列表。 + """ + name: str = proto.Field( + proto.STRING, + number=1, + ) + + url: str = proto.Field( + proto.STRING, + number=2, + ) + + apiType: str = proto.Field( + proto.STRING, + number=3, + ) + chargeStatus: str = proto.Field( + proto.STRING, + number=4, + ) + + versionList: MutableSequence["Version"] = proto.RepeatedField( + proto.MESSAGE, + number=5, + message="Version", + ) + + +class Version(proto.Message): + r"""服务版本 + 参数: + id(str): + 服务版本id,仅自定义服务有该字段。 + aiModelId(str): + 发布该服务版本的模型id,仅自定义服务有该字段。 + aiModelVersionId(str): + 发布该服务版本的模型版本id,仅自定义服务有该字段。 + trainType(str): + 服务基础模型类型。 + serviceStatus(str): + 服务状态。 + """ + id: str = proto.Field( + proto.STRING, + number=1, + ) + aiModelId: str = proto.Field( + proto.STRING, + number=2, + ) + aiModelVersionId: str = proto.Field( + proto.STRING, + number=3, + ) + trainType: str = proto.Field( + proto.STRING, + number=4, + ) + serviceStatus: str = proto.Field( + proto.STRING, + number=5, + ) + + +class Models: + r""" + 模型工具类,提供模型列表接口。 + """ + + def __init__(self, + secret_key: Optional[str] = None, + gateway: str = "" + ): + r"""Models初始化方法. + + 参数: + secret_key(str,可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", ""). + gateway(str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "") + 返回: + 无 + """ + self.http_client = HTTPClient(secret_key, gateway) + + def list(self, request: GetModelListRequest, timeout: float = None, + retry: int = 0) -> GetModelListResponse: + """ + 返回用户的模型列表信息。 + + 参数: + request (obj:`GetModelListRequest`):模型列表查询请求体。 + timeout (float, 可选): 请求的超时时间。 + retry (int, 可选): 请求的重试次数。 + + 返回: + obj:`GetModelListResponse`: 模型列表返回体。 + """ + url = self.http_client.service_url("/v1/bce/wenxinworkshop/service/list") + data = GetModelListRequest.to_json(request) + headers = self.http_client.auth_header() + headers['content-type'] = 'application/json' + if retry != self.http_client.retry.total: + self.http_client.retry.total = retry + response = self.http_client.session.post(url, data=data, headers=headers, timeout=timeout) + self.http_client.check_response_header(response) + data = response.json() + self.http_client.check_response_json(data) + request_id = self.http_client.response_request_id(response) + self.__class__._check_service_error(request_id, data) + response = GetModelListResponse.from_json(payload=json.dumps(data)) + response.request_id = request_id + return response + + @staticmethod + def _check_service_error(request_id: str, data: dict): + r"""服务response参数检查 + + 参数: + data (dict) : body返回 + 返回: + 无 + """ + if "error_code" in data and "error_msg" in data: + if data["error_code"] != 0: + raise appbuilder.AppBuilderServerException( + request_id=request_id, + service_err_code=data["error_code"], + service_err_message=data["error_msg"]) + + +def map_model_name(model_name: str) -> str: + r"""模型名称映射函数 + + 参数: + model_name (str) : 模型名称 + 返回: + str: 映射后的模型名称 + """ + return model_name_mapping.get(model_name, model_name) From f821fec32f541a677a1e6c9421b92cd1c34fa61b Mon Sep 17 00:00:00 2001 From: fangweimin Date: Fri, 29 Dec 2023 11:02:06 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E4=B8=BA=E5=8F=AF=E9=80=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- appbuilder/utils/model_util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/appbuilder/utils/model_util.py b/appbuilder/utils/model_util.py index c49bba29b..b6b86f3eb 100644 --- a/appbuilder/utils/model_util.py +++ b/appbuilder/utils/model_util.py @@ -191,7 +191,7 @@ def __init__(self, """ self.http_client = HTTPClient(secret_key, gateway) - def list(self, request: GetModelListRequest, timeout: float = None, + def list(self, request: GetModelListRequest = None, timeout: float = None, retry: int = 0) -> GetModelListResponse: """ 返回用户的模型列表信息。 @@ -205,6 +205,8 @@ def list(self, request: GetModelListRequest, timeout: float = None, obj:`GetModelListResponse`: 模型列表返回体。 """ url = self.http_client.service_url("/v1/bce/wenxinworkshop/service/list") + if request is None: + request = GetModelListRequest() data = GetModelListRequest.to_json(request) headers = self.http_client.auth_header() headers['content-type'] = 'application/json' From 3666fe56b5a53980af7bb79c07e2cf92d95970b3 Mon Sep 17 00:00:00 2001 From: fangweimin Date: Fri, 29 Dec 2023 11:14:18 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- appbuilder/tests/test_get_model_list.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/appbuilder/tests/test_get_model_list.py b/appbuilder/tests/test_get_model_list.py index 993eee6de..91da93055 100644 --- a/appbuilder/tests/test_get_model_list.py +++ b/appbuilder/tests/test_get_model_list.py @@ -17,9 +17,23 @@ def setUp(self): """ self.model = Models() + def get_model_list(self): + """ + get_model_list方法单测 + + Args: + None + + Returns: + None + + """ + response = appbuilder.get_model_list(apiTypefilter=["chat"]) + self.assertIsNotNone(response) + def test_list(self): """ - _list方法单测 + list方法单测 Args: None From afb0300ccf909c86232a6ccbba58a04423d1ef43 Mon Sep 17 00:00:00 2001 From: fangweimin Date: Wed, 3 Jan 2024 10:47:31 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E5=8F=AF=E7=94=A8=E7=8A=B6=E6=80=81=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- appbuilder/core/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/appbuilder/core/utils.py b/appbuilder/core/utils.py index 6cd70e6f0..92bd5cf5b 100644 --- a/appbuilder/core/utils.py +++ b/appbuilder/core/utils.py @@ -39,12 +39,12 @@ def get_model_list(secret_key: str = "", apiTypefilter: List[str] = [], is_avail models = [] if is_available: for common_model in response.result.common: - if common_model.chargeStatus == "OPENED": + if common_model.chargeStatus in ["OPENED", "FREE"]: mapped_name = map_model_name(common_model.name) models.append(mapped_name) for custom_model in response.result.custom: - if custom_model.chargeStatus == "OPENED": + if custom_model.chargeStatus in ["OPENED", "FREE"]: mapped_name = map_model_name(custom_model.name) models.append(mapped_name) return models From fc8cde8270d1e99a00ea03b7e38a86e3c0358b6d Mon Sep 17 00:00:00 2001 From: fangweimin Date: Wed, 3 Jan 2024 16:50:50 +0800 Subject: [PATCH 5/6] fix review comment --- appbuilder/core/utils.py | 32 ++++++++----------------- appbuilder/tests/test_get_model_list.py | 19 +++++++++++++-- appbuilder/utils/model_util.py | 14 +++++++++++ 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/appbuilder/core/utils.py b/appbuilder/core/utils.py index 92bd5cf5b..e87c1feb2 100644 --- a/appbuilder/core/utils.py +++ b/appbuilder/core/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools from typing import List from appbuilder.utils.model_util import GetModelListRequest, Models, map_model_name @@ -20,7 +21,7 @@ def utils_get_user_agent(): return 'appbuilder-sdk-python/{}'.format("__version__") -def get_model_list(secret_key: str = "", apiTypefilter: List[str] = [], is_available: bool = False): +def get_model_list(secret_key: str = "", api_type_filter: List[str] = [], is_available: bool = False) -> list: """ 返回用户的模型列表。 @@ -30,30 +31,17 @@ def get_model_list(secret_key: str = "", apiTypefilter: List[str] = [], is_avail is_available(bool, 可选): 返回可用模型列表, 默认返回所有模型。 返回: - List[str]: 模型列表。 + list: 模型列表。 """ request = GetModelListRequest() - request.apiTypefilter = apiTypefilter + request.apiTypefilter = api_type_filter model = Models(secret_key=secret_key) response = model.list(request) models = [] - if is_available: - for common_model in response.result.common: - if common_model.chargeStatus in ["OPENED", "FREE"]: - mapped_name = map_model_name(common_model.name) - models.append(mapped_name) - - for custom_model in response.result.custom: - if custom_model.chargeStatus in ["OPENED", "FREE"]: - mapped_name = map_model_name(custom_model.name) - models.append(mapped_name) - return models - else: - for common_model in response.result.common: - mapped_name = map_model_name(common_model.name) - models.append(mapped_name) - - for custom_model in response.result.custom: - mapped_name = map_model_name(custom_model.name) - models.append(mapped_name) + + for model in itertools.chain(response.result.common, response.result.custom): + if is_available and model.chargeStatus not in ["OPENED", "FREE"]: + continue + mapped_name = map_model_name(model.name) + models.append(mapped_name) return models diff --git a/appbuilder/tests/test_get_model_list.py b/appbuilder/tests/test_get_model_list.py index 91da93055..b63730336 100644 --- a/appbuilder/tests/test_get_model_list.py +++ b/appbuilder/tests/test_get_model_list.py @@ -1,5 +1,18 @@ -import unittest +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest import appbuilder from appbuilder.utils.model_util import GetModelListRequest, Models, GetModelListResponse @@ -28,8 +41,10 @@ def get_model_list(self): None """ - response = appbuilder.get_model_list(apiTypefilter=["chat"]) + response = appbuilder.get_model_list(api_type_filter=["chat"]) self.assertIsNotNone(response) + self.assertIsInstance(response, GetModelListResponse) + self.assertTrue(response.success) def test_list(self): """ diff --git a/appbuilder/utils/model_util.py b/appbuilder/utils/model_util.py index b6b86f3eb..9d9f10126 100644 --- a/appbuilder/utils/model_util.py +++ b/appbuilder/utils/model_util.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import proto from typing import Optional, MutableSequence From 7d97965c398c4d646d4f5c68dcc458d4155fcbfd Mon Sep 17 00:00:00 2001 From: fangweimin Date: Wed, 3 Jan 2024 17:44:00 +0800 Subject: [PATCH 6/6] fix review comment 2 --- appbuilder/core/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/appbuilder/core/utils.py b/appbuilder/core/utils.py index e87c1feb2..207180b22 100644 --- a/appbuilder/core/utils.py +++ b/appbuilder/core/utils.py @@ -27,8 +27,8 @@ def get_model_list(secret_key: str = "", api_type_filter: List[str] = [], is_ava 参数: secret_key(str,可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "")。 - apiTypefilter(List[str], 可选): 根据apiType过滤,["chat", "completions", "embeddings", "text2image"],不填包括所有的。 - is_available(bool, 可选): 返回可用模型列表, 默认返回所有模型。 + api_type_filter(List[str], 可选): 根据apiType过滤,["chat", "completions", "embeddings", "text2image"],不填包括所有的。 + is_available(bool, 可选): 是否返回可用模型列表, 默认返回所有模型。 返回: list: 模型列表。