Skip to content

Commit

Permalink
fix(prompts): use discriminated union for content part
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Jan 30, 2025
1 parent afd7e39 commit 79441e0
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 24 deletions.
41 changes: 27 additions & 14 deletions schemas/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1982,15 +1982,16 @@
"type": {
"type": "string",
"const": "image",
"title": "Type",
"default": "image"
"title": "Type"
},
"image": {
"$ref": "#/components/schemas/ImageContentValue"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"type",
"image"
],
"title": "ImageContentPart"
Expand Down Expand Up @@ -2179,8 +2180,7 @@
"version": {
"type": "string",
"const": "chat-template-v1",
"title": "Version",
"default": "chat-template-v1"
"title": "Version"
},
"messages": {
"items": {
Expand All @@ -2193,6 +2193,7 @@
"additionalProperties": false,
"type": "object",
"required": [
"version",
"messages"
],
"title": "PromptChatTemplateV1"
Expand All @@ -2204,7 +2205,7 @@
},
"content": {
"items": {
"anyOf": [
"oneOf": [
{
"$ref": "#/components/schemas/TextContentPart"
},
Expand All @@ -2217,7 +2218,16 @@
{
"$ref": "#/components/schemas/ToolResultContentPart"
}
]
],
"discriminator": {
"propertyName": "type",
"mapping": {
"image": "#/components/schemas/ImageContentPart",
"text": "#/components/schemas/TextContentPart",
"tool_call": "#/components/schemas/ToolCallContentPart",
"tool_result": "#/components/schemas/ToolResultContentPart"
}
}
},
"type": "array",
"title": "Content"
Expand Down Expand Up @@ -2260,8 +2270,7 @@
"version": {
"type": "string",
"const": "string-template-v1",
"title": "Version",
"default": "string-template-v1"
"title": "Version"
},
"template": {
"type": "string",
Expand All @@ -2271,6 +2280,7 @@
"additionalProperties": false,
"type": "object",
"required": [
"version",
"template"
],
"title": "PromptStringTemplateV1"
Expand Down Expand Up @@ -2508,15 +2518,16 @@
"type": {
"type": "string",
"const": "text",
"title": "Type",
"default": "text"
"title": "Type"
},
"text": {
"$ref": "#/components/schemas/TextContentValue"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"type",
"text"
],
"title": "TextContentPart"
Expand All @@ -2539,15 +2550,16 @@
"type": {
"type": "string",
"const": "tool_call",
"title": "Type",
"default": "tool_call"
"title": "Type"
},
"tool_call": {
"$ref": "#/components/schemas/ToolCallContentValue"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"type",
"tool_call"
],
"title": "ToolCallContentPart"
Expand Down Expand Up @@ -2598,15 +2610,16 @@
"type": {
"type": "string",
"const": "tool_result",
"title": "Type",
"default": "tool_result"
"title": "Type"
},
"tool_result": {
"$ref": "#/components/schemas/ToolResultContentValue"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"type",
"tool_result"
],
"title": "ToolResultContentPart"
Expand Down
16 changes: 8 additions & 8 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PromptModel(BaseModel):
)


class PartBase(BaseModel):
class PartBase(PromptModel):
type: Literal["text", "image", "tool", "tool_call", "tool_result"]


Expand All @@ -58,7 +58,7 @@ class TextContentValue(BaseModel):


class TextContentPart(PartBase):
type: Literal["text"] = Field(default="text")
type: Literal["text"]
text: TextContentValue


Expand All @@ -69,7 +69,7 @@ class ImageContentValue(BaseModel):


class ImageContentPart(PartBase):
type: Literal["image"] = Field(default="image")
type: Literal["image"]
# the image data
image: ImageContentValue

Expand All @@ -86,7 +86,7 @@ class ToolCallContentValue(BaseModel):


class ToolCallContentPart(PartBase):
type: Literal["tool_call"] = Field(default="tool_call")
type: Literal["tool_call"]
# the identifier of the tool call function
tool_call: ToolCallContentValue

Expand All @@ -97,13 +97,13 @@ class ToolResultContentValue(BaseModel):


class ToolResultContentPart(PartBase):
type: Literal["tool_result"] = Field(default="tool_result")
type: Literal["tool_result"]
tool_result: ToolResultContentValue


ContentPart: TypeAlias = Annotated[
Union[TextContentPart, ImageContentPart, ToolCallContentPart, ToolResultContentPart],
Field(),
Field(..., discriminator="type"),
]


Expand All @@ -113,12 +113,12 @@ class PromptMessage(PromptModel):


class PromptChatTemplateV1(PromptModel):
version: Literal["chat-template-v1"] = "chat-template-v1"
version: Literal["chat-template-v1"]
messages: list[PromptMessage]


class PromptStringTemplateV1(PromptModel):
version: Literal["string-template-v1"] = "string-template-v1"
version: Literal["string-template-v1"]
template: str


Expand Down
55 changes: 55 additions & 0 deletions src/phoenix/server/api/input_types/PromptVersionInput.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from typing import Optional

import strawberry
from strawberry import UNSET
from strawberry.scalars import JSON

from phoenix.server.api.helpers.prompts.models import (
ContentPart,
ImageContentPart,
ImageContentValue,
PromptChatTemplateV1,
PromptMessage,
PromptTemplateFormat,
PromptToolDefinition,
TextContentPart,
TextContentValue,
ToolCallContentPart,
ToolCallContentValue,
ToolCallFunction,
ToolResultContentPart,
ToolResultContentValue,
)
from phoenix.server.api.helpers.prompts.models import (
Expand Down Expand Up @@ -85,3 +93,50 @@ class ChatPromptVersionInput:
output_schema: Optional[OutputSchemaInput] = None
model_provider: str
model_name: str


def to_pydantic_prompt_chat_template_v1(
prompt_chat_template_input: PromptChatTemplateInput,
) -> PromptChatTemplateV1:
return PromptChatTemplateV1(
version="chat-template-v1",
messages=[
to_pydantic_prompt_message(message) for message in prompt_chat_template_input.messages
],
)


def to_pydantic_prompt_message(prompt_message_input: PromptMessageInput) -> PromptMessage:
return PromptMessage(
role=prompt_message_input.role,
content=[
to_pydantic_content_part(content_part) for content_part in prompt_message_input.content
],
)


def to_pydantic_content_part(content_part_input: ContentPartInput) -> ContentPart:
content_part_cls: type[ContentPart]
if content_part_input.text is not UNSET:
content_part_cls = TextContentPart
content_part_type = "text"
elif content_part_input.image is not UNSET:
content_part_cls = ImageContentPart
content_part_type = "image"
elif content_part_input.tool_call is not UNSET:
content_part_cls = ToolCallContentPart
content_part_type = "tool_call"
elif content_part_input.tool_result is not UNSET:
content_part_cls = ToolResultContentPart
content_part_type = "tool_result"
else:
raise ValueError("content part input has no content")
content_part_data = {
k: v for k, v in strawberry.asdict(content_part_input).items() if v is not UNSET
}
return content_part_cls.model_validate(
{
"type": content_part_type,
**content_part_data,
}
)
7 changes: 6 additions & 1 deletion src/phoenix/server/api/mutations/prompt_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
PromptToolsV1,
PromptVersion,
)
from phoenix.server.api.input_types.PromptVersionInput import ChatPromptVersionInput
from phoenix.server.api.input_types.PromptVersionInput import (
ChatPromptVersionInput,
to_pydantic_prompt_chat_template_v1,
)
from phoenix.server.api.mutations.prompt_version_tag_mutations import (
SetPromptVersionTagInput,
upsert_prompt_version_tag,
Expand Down Expand Up @@ -89,6 +92,7 @@ async def create_chat_prompt(
**{
**strawberry.asdict(input.prompt_version),
"tools": tools,
"template": to_pydantic_prompt_chat_template_v1(input.prompt_version.template),
},
template_type="CHAT",
user_id=user_id,
Expand Down Expand Up @@ -144,6 +148,7 @@ async def create_chat_prompt_version(
**{
**strawberry.asdict(input.prompt_version),
"tools": tools,
"template": to_pydantic_prompt_chat_template_v1(input.prompt_version.template),
},
template_type="CHAT",
user_id=user_id,
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/server/api/routers/v1/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ async def _insert_prompt_versions(
template_format="MUSTACHE",
template=PromptChatTemplateV1.model_validate(
{
"version": "chat-template-v1",
"messages": [
{
"role": "USER",
Expand Down Expand Up @@ -168,7 +169,7 @@ async def _insert_prompt_versions(
},
],
}
]
],
}
),
invocation_parameters={},
Expand Down

0 comments on commit 79441e0

Please sign in to comment.