Skip to content

Commit

Permalink
修改references类型的保留字段
Browse files Browse the repository at this point in the history
  • Loading branch information
yepeiwen01 committed Dec 23, 2024
1 parent 1a03ce2 commit 6155800
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
13 changes: 3 additions & 10 deletions python/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,11 @@ class OralText(BaseModel, extra='allow'):

class References(BaseModel, extra='allow'):
type: str = Field(default="", description="类型")
resource_type: str = Field(default="", description="资源类型")
icon: str = Field(default="", description="站点图标")
site_name: str = Field(default="", description="站点名")
source: str = Field(default="", description="来源")
doc_id: str = Field(default="", description="文档id")
title: str = Field(default="", description="标题")
content: str = Field(default="", description="内容")
image_content: str = Field(default="", description="图片内容")
mock_id: Optional[str] = Field(default="", description="模拟数据id")
image_url: str = Field(default="", description="图片url")
video_url: str = Field(default="", description="视频url")
extra: Optional[dict] = Field(default={}, description="其他信息")


class Image(BaseModel, extra='allow'):
Expand Down Expand Up @@ -548,8 +542,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra
elif type == "files":
key_list = ["filename", "url"]
elif type == "references":
key_list = ["type", "resource_type", "icon", "site_name", "source",
"doc_id", "title", "content", "image_content", "image_url", "video_url"]
key_list = ["type", "source", "doc_id", "title", "content"]
elif type == "image":
key_list = ["filename", "url"]
elif type == "chart":
Expand All @@ -562,7 +555,7 @@ def create_output(cls, type, text, role="tool", name="", visible_scope="all", ra
key_list = ["thought", "name", "arguments"]
else:
raise ValueError("Unknown type: {}".format(type))
# assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list)
assert all(key in text for key in key_list), "all keys:{} must be included in the text field".format(key_list)
else:
raise ValueError("text must be str or dict")

Expand Down
3 changes: 3 additions & 0 deletions python/tests/test_base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_valid_output_with_dict(self):
output8 = self.component.create_output(type="audio", text={"filename": "file.mp3", "url": "http://www.baidu.com"})
output9 = self.component.create_output(type="plan", text={"detail": "hello", "steps":[{"name": "1", "arguments": {"query": "a", "chat_history": "world"}}]})
output10 = self.component.create_output(type="function_call", text={"thought": "hello", "name": "AppBuilder", "arguments": {"query": "a", "chat_history": "world"}})
output11 = self.component.create_output(type="references", text={"type": "engine", "doc_id": "1", "content": "hello, world", "title": "Have a nice day", "source": "bing", "extra": {"key": "value"}})
self.assertIsInstance(output1, ComponentOutput)
self.assertIsInstance(output2, ComponentOutput)
self.assertIsInstance(output3, ComponentOutput)
Expand All @@ -42,6 +43,8 @@ def test_valid_output_with_dict(self):
self.assertIsInstance(output8, ComponentOutput)
self.assertIsInstance(output9, ComponentOutput)
self.assertIsInstance(output10, ComponentOutput)
self.assertIsInstance(output11, ComponentOutput)
self.assertEqual(output11.content[0].text.extra["key"], "value")

def test_valid_output_type_with_same_key(self):
output1 = self.component.create_output(type="urls", text={"url": "http://www.baidu.com"})
Expand Down

0 comments on commit 6155800

Please sign in to comment.