diff --git a/python/core/component.py b/python/core/component.py index 9a37c195..4254bbef 100644 --- a/python/core/component.py +++ b/python/core/component.py @@ -127,7 +127,7 @@ class FunctionCall(BaseModel, extra='allow'): arguments: dict = Field(default={}, description="参数列表") class Json(BaseModel, extra='allow'): - data: dict = Field(default="", description="json数据") + data: str = Field(default="", description="json数据") class Content(BaseModel): name: str = Field(default="", @@ -517,6 +517,8 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra text = {"url": text} elif type == "oral_text": text = {"info": text} + elif type == "json": + text = {"data": text} else: raise ValueError("Only when type=text/code/urls/oral_text, string text is allowed! Please give dict text") elif isinstance(text, dict): @@ -542,8 +544,6 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra key_list = ["detail", "steps"] elif type == "function_call": key_list = ["thought", "name", "arguments"] - elif type == "json": - key_list = ["data"] else: raise ValueError("Unknown type: {}".format(type)) # assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list) diff --git a/python/tests/test_core_components.py b/python/tests/test_core_components.py new file mode 100644 index 00000000..ebeb1544 --- /dev/null +++ b/python/tests/test_core_components.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import json +import unittest + +from appbuilder.core.component import Component + + +@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "") +class TestObjectRecognition(unittest.TestCase): + def setUp(self): + self.com = Component() + self.json = json.dumps({ + "type": "object_recognition", + "text": "https://baidu.com/1.jpg", + }) + + def test_create_output(self): + result = self.com.create_output( + type="json", + text=self.json, + ) + print(result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file