From 79441e0ee8d203c02b742070c5e624dc71d82b3c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 29 Jan 2025 16:25:50 -0800 Subject: [PATCH] fix(prompts): use discriminated union for content part --- schemas/openapi.json | 41 +++++++++----- .../server/api/helpers/prompts/models.py | 16 +++--- .../api/input_types/PromptVersionInput.py | 55 +++++++++++++++++++ .../server/api/mutations/prompt_mutations.py | 7 ++- .../server/api/routers/v1/test_prompts.py | 3 +- 5 files changed, 98 insertions(+), 24 deletions(-) diff --git a/schemas/openapi.json b/schemas/openapi.json index eabea2a496..cd6c117665 100644 --- a/schemas/openapi.json +++ b/schemas/openapi.json @@ -1982,15 +1982,16 @@ "type": { "type": "string", "const": "image", - "title": "Type", - "default": "image" + "title": "Type" }, "image": { "$ref": "#/components/schemas/ImageContentValue" } }, + "additionalProperties": false, "type": "object", "required": [ + "type", "image" ], "title": "ImageContentPart" @@ -2179,8 +2180,7 @@ "version": { "type": "string", "const": "chat-template-v1", - "title": "Version", - "default": "chat-template-v1" + "title": "Version" }, "messages": { "items": { @@ -2193,6 +2193,7 @@ "additionalProperties": false, "type": "object", "required": [ + "version", "messages" ], "title": "PromptChatTemplateV1" @@ -2204,7 +2205,7 @@ }, "content": { "items": { - "anyOf": [ + "oneOf": [ { "$ref": "#/components/schemas/TextContentPart" }, @@ -2217,7 +2218,16 @@ { "$ref": "#/components/schemas/ToolResultContentPart" } - ] + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "image": "#/components/schemas/ImageContentPart", + "text": "#/components/schemas/TextContentPart", + "tool_call": "#/components/schemas/ToolCallContentPart", + "tool_result": "#/components/schemas/ToolResultContentPart" + } + } }, "type": "array", "title": "Content" @@ -2260,8 +2270,7 @@ "version": { "type": "string", "const": "string-template-v1", - "title": "Version", - "default": "string-template-v1" + "title": "Version" }, "template": { "type": "string", @@ -2271,6 +2280,7 @@ "additionalProperties": false, "type": "object", "required": [ + "version", "template" ], "title": "PromptStringTemplateV1" @@ -2508,15 +2518,16 @@ "type": { "type": "string", "const": "text", - "title": "Type", - "default": "text" + "title": "Type" }, "text": { "$ref": "#/components/schemas/TextContentValue" } }, + "additionalProperties": false, "type": "object", "required": [ + "type", "text" ], "title": "TextContentPart" @@ -2539,15 +2550,16 @@ "type": { "type": "string", "const": "tool_call", - "title": "Type", - "default": "tool_call" + "title": "Type" }, "tool_call": { "$ref": "#/components/schemas/ToolCallContentValue" } }, + "additionalProperties": false, "type": "object", "required": [ + "type", "tool_call" ], "title": "ToolCallContentPart" @@ -2598,15 +2610,16 @@ "type": { "type": "string", "const": "tool_result", - "title": "Type", - "default": "tool_result" + "title": "Type" }, "tool_result": { "$ref": "#/components/schemas/ToolResultContentValue" } }, + "additionalProperties": false, "type": "object", "required": [ + "type", "tool_result" ], "title": "ToolResultContentPart" diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 298de9110a..0c32453197 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -49,7 +49,7 @@ class PromptModel(BaseModel): ) -class PartBase(BaseModel): +class PartBase(PromptModel): type: Literal["text", "image", "tool", "tool_call", "tool_result"] @@ -58,7 +58,7 @@ class TextContentValue(BaseModel): class TextContentPart(PartBase): - type: Literal["text"] = Field(default="text") + type: Literal["text"] text: TextContentValue @@ -69,7 +69,7 @@ class ImageContentValue(BaseModel): class ImageContentPart(PartBase): - type: Literal["image"] = Field(default="image") + type: Literal["image"] # the image data image: ImageContentValue @@ -86,7 +86,7 @@ class ToolCallContentValue(BaseModel): class ToolCallContentPart(PartBase): - type: Literal["tool_call"] = Field(default="tool_call") + type: Literal["tool_call"] # the identifier of the tool call function tool_call: ToolCallContentValue @@ -97,13 +97,13 @@ class ToolResultContentValue(BaseModel): class ToolResultContentPart(PartBase): - type: Literal["tool_result"] = Field(default="tool_result") + type: Literal["tool_result"] tool_result: ToolResultContentValue ContentPart: TypeAlias = Annotated[ Union[TextContentPart, ImageContentPart, ToolCallContentPart, ToolResultContentPart], - Field(), + Field(..., discriminator="type"), ] @@ -113,12 +113,12 @@ class PromptMessage(PromptModel): class PromptChatTemplateV1(PromptModel): - version: Literal["chat-template-v1"] = "chat-template-v1" + version: Literal["chat-template-v1"] messages: list[PromptMessage] class PromptStringTemplateV1(PromptModel): - version: Literal["string-template-v1"] = "string-template-v1" + version: Literal["string-template-v1"] template: str diff --git a/src/phoenix/server/api/input_types/PromptVersionInput.py b/src/phoenix/server/api/input_types/PromptVersionInput.py index 33b50cd912..c7415323d5 100644 --- a/src/phoenix/server/api/input_types/PromptVersionInput.py +++ b/src/phoenix/server/api/input_types/PromptVersionInput.py @@ -1,15 +1,23 @@ from typing import Optional import strawberry +from strawberry import UNSET from strawberry.scalars import JSON from phoenix.server.api.helpers.prompts.models import ( + ContentPart, + ImageContentPart, ImageContentValue, + PromptChatTemplateV1, + PromptMessage, PromptTemplateFormat, PromptToolDefinition, + TextContentPart, TextContentValue, + ToolCallContentPart, ToolCallContentValue, ToolCallFunction, + ToolResultContentPart, ToolResultContentValue, ) from phoenix.server.api.helpers.prompts.models import ( @@ -85,3 +93,50 @@ class ChatPromptVersionInput: output_schema: Optional[OutputSchemaInput] = None model_provider: str model_name: str + + +def to_pydantic_prompt_chat_template_v1( + prompt_chat_template_input: PromptChatTemplateInput, +) -> PromptChatTemplateV1: + return PromptChatTemplateV1( + version="chat-template-v1", + messages=[ + to_pydantic_prompt_message(message) for message in prompt_chat_template_input.messages + ], + ) + + +def to_pydantic_prompt_message(prompt_message_input: PromptMessageInput) -> PromptMessage: + return PromptMessage( + role=prompt_message_input.role, + content=[ + to_pydantic_content_part(content_part) for content_part in prompt_message_input.content + ], + ) + + +def to_pydantic_content_part(content_part_input: ContentPartInput) -> ContentPart: + content_part_cls: type[ContentPart] + if content_part_input.text is not UNSET: + content_part_cls = TextContentPart + content_part_type = "text" + elif content_part_input.image is not UNSET: + content_part_cls = ImageContentPart + content_part_type = "image" + elif content_part_input.tool_call is not UNSET: + content_part_cls = ToolCallContentPart + content_part_type = "tool_call" + elif content_part_input.tool_result is not UNSET: + content_part_cls = ToolResultContentPart + content_part_type = "tool_result" + else: + raise ValueError("content part input has no content") + content_part_data = { + k: v for k, v in strawberry.asdict(content_part_input).items() if v is not UNSET + } + return content_part_cls.model_validate( + { + "type": content_part_type, + **content_part_data, + } + ) diff --git a/src/phoenix/server/api/mutations/prompt_mutations.py b/src/phoenix/server/api/mutations/prompt_mutations.py index a2c3c17219..b4b899724c 100644 --- a/src/phoenix/server/api/mutations/prompt_mutations.py +++ b/src/phoenix/server/api/mutations/prompt_mutations.py @@ -18,7 +18,10 @@ PromptToolsV1, PromptVersion, ) -from phoenix.server.api.input_types.PromptVersionInput import ChatPromptVersionInput +from phoenix.server.api.input_types.PromptVersionInput import ( + ChatPromptVersionInput, + to_pydantic_prompt_chat_template_v1, +) from phoenix.server.api.mutations.prompt_version_tag_mutations import ( SetPromptVersionTagInput, upsert_prompt_version_tag, @@ -89,6 +92,7 @@ async def create_chat_prompt( **{ **strawberry.asdict(input.prompt_version), "tools": tools, + "template": to_pydantic_prompt_chat_template_v1(input.prompt_version.template), }, template_type="CHAT", user_id=user_id, @@ -144,6 +148,7 @@ async def create_chat_prompt_version( **{ **strawberry.asdict(input.prompt_version), "tools": tools, + "template": to_pydantic_prompt_chat_template_v1(input.prompt_version.template), }, template_type="CHAT", user_id=user_id, diff --git a/tests/unit/server/api/routers/v1/test_prompts.py b/tests/unit/server/api/routers/v1/test_prompts.py index 8d0accaf4b..6426124e11 100644 --- a/tests/unit/server/api/routers/v1/test_prompts.py +++ b/tests/unit/server/api/routers/v1/test_prompts.py @@ -137,6 +137,7 @@ async def _insert_prompt_versions( template_format="MUSTACHE", template=PromptChatTemplateV1.model_validate( { + "version": "chat-template-v1", "messages": [ { "role": "USER", @@ -168,7 +169,7 @@ async def _insert_prompt_versions( }, ], } - ] + ], } ), invocation_parameters={},