diff --git a/python/core/component.py b/python/core/component.py index 23bcbac7..5628ed9f 100644 --- a/python/core/component.py +++ b/python/core/component.py @@ -83,17 +83,11 @@ class OralText(BaseModel, extra='allow'): class References(BaseModel, extra='allow'): type: str = Field(default="", description="类型") - resource_type: str = Field(default="", description="资源类型") - icon: str = Field(default="", description="站点图标") - site_name: str = Field(default="", description="站点名") source: str = Field(default="", description="来源") doc_id: str = Field(default="", description="文档id") title: str = Field(default="", description="标题") content: str = Field(default="", description="内容") - image_content: str = Field(default="", description="图片内容") - mock_id: Optional[str] = Field(default="", description="模拟数据id") - image_url: str = Field(default="", description="图片url") - video_url: str = Field(default="", description="视频url") + extra: Optional[dict] = Field(default={}, description="其他信息") class Image(BaseModel, extra='allow'): @@ -548,8 +542,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra elif type == "files": key_list = ["filename", "url"] elif type == "references": - key_list = ["type", "resource_type", "icon", "site_name", "source", - "doc_id", "title", "content", "image_content", "image_url", "video_url"] + key_list = ["type", "source", "doc_id", "title", "content"] elif type == "image": key_list = ["filename", "url"] elif type == "chart": @@ -562,7 +555,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra key_list = ["thought", "name", "arguments"] else: raise ValueError("Unknown type: {}".format(type)) - # assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list) + assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list) else: raise ValueError("text must be str or dict") diff --git a/python/tests/test_base_component.py b/python/tests/test_base_component.py index 94d5ff4e..a560cf78 100644 --- a/python/tests/test_base_component.py +++ b/python/tests/test_base_component.py @@ -32,6 +32,7 @@ def test_valid_output_with_dict(self): output8 = self.component.create_output(type="audio", text={"filename": "file.mp3", "url": "http://www.baidu.com"}) output9 = self.component.create_output(type="plan", text={"detail": "hello", "steps":[{"name": "1", "arguments": {"query": "a", "chat_history": "world"}}]}) output10 = self.component.create_output(type="function_call", text={"thought": "hello", "name": "AppBuilder", "arguments": {"query": "a", "chat_history": "world"}}) + output11 = self.component.create_output(type="references", text={"type": "engine", "doc_id": "1", "content": "hello, world", "title": "Have a nice day", "source": "bing", "extra": {"key": "value"}}) self.assertIsInstance(output1, ComponentOutput) self.assertIsInstance(output2, ComponentOutput) self.assertIsInstance(output3, ComponentOutput) @@ -42,6 +43,8 @@ def test_valid_output_with_dict(self): self.assertIsInstance(output8, ComponentOutput) self.assertIsInstance(output9, ComponentOutput) self.assertIsInstance(output10, ComponentOutput) + self.assertIsInstance(output11, ComponentOutput) + self.assertEqual(output11.content[0].text.extra["key"], "value") def test_valid_output_type_with_same_key(self): output1 = self.component.create_output(type="urls", text={"url": "http://www.baidu.com"})