Skip to content

Commit

Permalink
Merge pull request #29 from HorseDream/master
Browse files Browse the repository at this point in the history
增加模型列表获取的SDK
  • Loading branch information
seiriosPlus authored Jan 4, 2024
2 parents af2adc5 + 7d97965 commit d12a6a4
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 4 deletions.
3 changes: 3 additions & 0 deletions appbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def check_version(self):

from appbuilder.utils.logger_util import logger

from appbuilder.core.utils import get_model_list


from .core._exception import (
BadRequestException,
ForbiddenException,
Expand Down
7 changes: 4 additions & 3 deletions appbuilder/core/components/asr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def run(self, message: Message, audio_format: str = "pcm", rate: int = 16000,
request.cuid = str(uuid.uuid4())
request.dev_pid = "80001"
request.speech = inp.raw_audio
response = self._recognize(request)
response = self._recognize(request, timeout, retry)
out = ASROutMsg(result=list(response.result))
return Message(content=dict(out))

def _recognize(self, request: ShortSpeechRecognitionRequest, timeout: float = None,
retry: int = 0) -> ShortSpeechRecognitionResponse:
retry: int = 0) -> ShortSpeechRecognitionResponse:
"""
使用给定的输入并返回语音识别的结果。
Expand All @@ -89,7 +89,8 @@ def _recognize(self, request: ShortSpeechRecognitionRequest, timeout: float = No
}
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry
response = self.http_client.session.post(self.http_client.service_url("/v1/bce/aip_speech/asrpro"), params=params, headers=headers, data=request.speech, timeout=timeout)
response = self.http_client.session.post(self.http_client.service_url("/v1/bce/aip_speech/asrpro"),
params=params, headers=headers, data=request.speech, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
self.http_client.check_response_json(data)
Expand Down
29 changes: 28 additions & 1 deletion appbuilder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import List


from appbuilder.utils.model_util import GetModelListRequest, Models, map_model_name


def utils_get_user_agent():
return 'appbuilder-sdk-python/{}'.format("__version__")


def get_model_list(secret_key: str = "", api_type_filter: List[str] = [], is_available: bool = False) -> list:
"""
返回用户的模型列表。
参数:
secret_key(str,可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "")。
api_type_filter(List[str], 可选): 根据apiType过滤,["chat", "completions", "embeddings", "text2image"],不填包括所有的。
is_available(bool, 可选): 是否返回可用模型列表, 默认返回所有模型。
返回:
list: 模型列表。
"""
request = GetModelListRequest()
request.apiTypefilter = api_type_filter
model = Models(secret_key=secret_key)
response = model.list(request)
models = []

for model in itertools.chain(response.result.common, response.result.custom):
if is_available and model.chargeStatus not in ["OPENED", "FREE"]:
continue
mapped_name = map_model_name(model.name)
models.append(mapped_name)
return models
86 changes: 86 additions & 0 deletions appbuilder/tests/test_get_model_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import appbuilder
from appbuilder.utils.model_util import GetModelListRequest, Models, GetModelListResponse


class TestModels(unittest.TestCase):
def setUp(self):
"""
设置环境变量。
Args:
None
Returns:
None.
"""
self.model = Models()

def get_model_list(self):
"""
get_model_list方法单测
Args:
None
Returns:
None
"""
response = appbuilder.get_model_list(api_type_filter=["chat"])
self.assertIsNotNone(response)
self.assertIsInstance(response, GetModelListResponse)
self.assertTrue(response.success)

def test_list(self):
"""
list方法单测
Args:
None
Returns:
None
"""

request = GetModelListRequest()
response = self.model.list(request)
self.assertIsNotNone(response)
self.assertIsInstance(response, GetModelListResponse)

def test_check_service_error(self):
"""
check_service_error方法单测
Args:
None
Returns:
None
"""
data = {'error_msg': 'Error', 'error_code': 1}
request_id = "request_id"
with self.assertRaises(appbuilder.AppBuilderServerException):
self.model._check_service_error(request_id, data)
data = {'error_msg': 'No Error', 'error_code': 0}
self.assertIsNone(self.model._check_service_error(request_id, data))


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit d12a6a4

Please sign in to comment.