diff --git a/appbuilder/core/components/handwrite_ocr/component.py b/appbuilder/core/components/handwrite_ocr/component.py index 1d7b24a00..11a126c04 100644 --- a/appbuilder/core/components/handwrite_ocr/component.py +++ b/appbuilder/core/components/handwrite_ocr/component.py @@ -18,6 +18,7 @@ from appbuilder.core.components.handwrite_ocr.model import * from appbuilder.core.message import Message from appbuilder.core._client import HTTPClient +from appbuilder.core import utils class HandwriteOCR(Component): r""" 手写文字识别组件 @@ -106,7 +107,10 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): file_names = kwargs.get("files") file_urls = kwargs.get("file_urls", {}) for file_name in file_names: - file_url = file_urls.get(file_name, None) + if utils.is_url(file_name): + file_url = file_name + else: + file_url = file_urls.get(file_name, None) if file_url is None: raise InvalidRequestArgumentError(f"file {file_name} url does not exist") req = HandwriteOCRRequest() diff --git a/appbuilder/core/components/mix_card_ocr/component.py b/appbuilder/core/components/mix_card_ocr/component.py index f701a573d..93f352bbd 100644 --- a/appbuilder/core/components/mix_card_ocr/component.py +++ b/appbuilder/core/components/mix_card_ocr/component.py @@ -13,6 +13,8 @@ r"""身份证混贴识别组件""" import base64 import json + +from appbuilder.core import utils from appbuilder.core._client import HTTPClient from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError from appbuilder.core.component import Component @@ -154,7 +156,10 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): file_names = kwargs.get("files") file_urls = kwargs.get("file_urls", {}) for file_name in file_names: - file_url = file_urls.get(file_name, None) + if utils.is_url(file_name): + file_url = file_name + else: + file_url = file_urls.get(file_name, None) if file_url is None: raise InvalidRequestArgumentError(f"file {file_name} url does not exist") diff --git a/appbuilder/core/components/qrcode_ocr/component.py b/appbuilder/core/components/qrcode_ocr/component.py index 40682301c..28f58bc26 100644 --- a/appbuilder/core/components/qrcode_ocr/component.py +++ b/appbuilder/core/components/qrcode_ocr/component.py @@ -17,6 +17,7 @@ import base64 import json +from appbuilder.core import utils from appbuilder.core.component import Component from appbuilder.core.components.qrcode_ocr.model import * from appbuilder.core.message import Message @@ -154,7 +155,10 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): file_urls = kwargs.get("file_urls", {}) for file_name in file_names: - file_url = file_urls.get(file_name, None) + if utils.is_url(file_name): + file_url = file_name + else: + file_url = file_urls.get(file_name, None) if file_url is None: raise InvalidRequestArgumentError(f"file {file_name} url does not exist") req = QRcodeRequest() diff --git a/appbuilder/core/components/table_ocr/component.py b/appbuilder/core/components/table_ocr/component.py index e809f961c..0947d36be 100644 --- a/appbuilder/core/components/table_ocr/component.py +++ b/appbuilder/core/components/table_ocr/component.py @@ -17,6 +17,7 @@ import base64 import json +from appbuilder.core import utils from appbuilder.core.component import Component from appbuilder.core.components.table_ocr.model import * from appbuilder.core.message import Message @@ -180,7 +181,10 @@ def tool_eval(self, name: str, streaming: bool, **kwargs): file_names = kwargs.get("files") file_urls = kwargs.get("file_urls", {}) for file_name in file_names: - file_url = file_urls.get(file_name, None) + if utils.is_url(file_name): + file_url = file_name + else: + file_url = file_urls.get(file_name, None) if file_url is None: raise InvalidRequestArgumentError(f"file {file_name} url does not exist") req = TableOCRRequest() diff --git a/appbuilder/core/utils.py b/appbuilder/core/utils.py index 91256dfc7..8b6c40859 100644 --- a/appbuilder/core/utils.py +++ b/appbuilder/core/utils.py @@ -13,7 +13,7 @@ # limitations under the License. import itertools from typing import List - +from urllib.parse import urlparse from appbuilder.core._client import HTTPClient from appbuilder.core._exception import TypeNotSupportedException, ModelNotSupportedException from appbuilder.utils.model_util import GetModelListRequest, Models, model_name_mapping @@ -63,6 +63,16 @@ def convert_cloudhub_url(client: HTTPClient, qianfan_url: str) -> str: return "{}/{}{}".format(client.gateway, cloudhub_url_prefix, url_suffix) +def is_url(string): + """ + 判断字符串是否是URL + :param string: + :return: + """ + result = urlparse(string) + return all([result.scheme, result.netloc]) + + class ModelInfo: """ 模型信息类 """