diff --git a/app/schema.graphql b/app/schema.graphql index 7d78c9a9ac..ce485960fb 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1401,6 +1401,10 @@ type PromptChatTemplate { messages: [PromptTemplateMessage!]! } +input PromptChatTemplateInput { + messages: [TextPromptMessageInput!]! +} + """A connection to a list of items.""" type PromptConnection { """Pagination data for this connection""" @@ -1491,7 +1495,7 @@ input PromptVersionInput { description: String = null templateType: PromptTemplateType! templateFormat: PromptTemplateFormat! - template: JSON! + template: PromptChatTemplateInput! invocationParameters: JSON! = {} tools: [ToolDefinitionInput!]! = [] outputSchema: JSONSchemaInput = null @@ -1847,6 +1851,11 @@ type TextPromptMessage { content: String! } +input TextPromptMessageInput { + role: PromptMessageRole! + content: String! +} + input TimeRange { """The start of the time range""" start: DateTime! diff --git a/src/phoenix/server/api/input_types/PromptVersionInput.py b/src/phoenix/server/api/input_types/PromptVersionInput.py index 3bde6c7abf..9d0f77e086 100644 --- a/src/phoenix/server/api/input_types/PromptVersionInput.py +++ b/src/phoenix/server/api/input_types/PromptVersionInput.py @@ -4,15 +4,15 @@ from strawberry.scalars import JSON from phoenix.server.api.helpers.prompts.models import ( + PromptChatTemplateV1, PromptTemplateFormat, PromptTemplateType, -) -from phoenix.server.api.helpers.prompts.models import ( - PromptToolDefinition as ToolDefinitionModel, + PromptToolDefinition, + TextPromptMessage, ) -@strawberry.experimental.pydantic.input(ToolDefinitionModel) +@strawberry.experimental.pydantic.input(PromptToolDefinition) class ToolDefinitionInput: definition: JSON @@ -22,12 +22,23 @@ class JSONSchemaInput: definition: JSON +@strawberry.experimental.pydantic.input(TextPromptMessage) +class TextPromptMessageInput: + role: strawberry.auto + content: strawberry.auto + + +@strawberry.experimental.pydantic.input(PromptChatTemplateV1) +class PromptChatTemplateInput: + messages: list[TextPromptMessageInput] + + @strawberry.input class PromptVersionInput: description: Optional[str] = None template_type: PromptTemplateType template_format: PromptTemplateFormat - template: JSON + template: PromptChatTemplateInput invocation_parameters: JSON = strawberry.field(default_factory=dict) tools: list[ToolDefinitionInput] = strawberry.field(default_factory=list) output_schema: Optional[JSONSchemaInput] = None diff --git a/src/phoenix/server/api/mutations/prompt_mutations.py b/src/phoenix/server/api/mutations/prompt_mutations.py index 2e8a623bcd..47247fc466 100644 --- a/src/phoenix/server/api/mutations/prompt_mutations.py +++ b/src/phoenix/server/api/mutations/prompt_mutations.py @@ -46,7 +46,9 @@ async def create_prompt(self, info: Info[Context, None], input: CreatePromptInpu if input.prompt_version.output_schema is not None else None ) - template = PromptChatTemplateV1.model_validate(input.prompt_version.template).dict() + template = PromptChatTemplateV1.model_validate( + strawberry.asdict(input.prompt_version.template) + ).dict() except ValidationError as error: raise BadRequest(str(error)) async with info.context.db() as session: diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py index 0f5c2b760e..4fcaf2699b 100644 --- a/tests/unit/server/api/mutations/test_prompt_mutations.py +++ b/tests/unit/server/api/mutations/test_prompt_mutations.py @@ -58,7 +58,7 @@ class TestPromptMutations: "description": "prompt-version-description", "templateType": "CHAT", "templateFormat": "MUSTACHE", - "template": {"messages": [{"role": "user", "content": "hello world"}]}, + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, "modelProvider": "openai", "modelName": "o1-mini", @@ -78,7 +78,7 @@ class TestPromptMutations: "description": "prompt-version-description", "templateType": "CHAT", "templateFormat": "MUSTACHE", - "template": {"messages": [{"role": "user", "content": "hello world"}]}, + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, "modelProvider": "openai", "modelName": "o1-mini", @@ -99,7 +99,7 @@ class TestPromptMutations: "description": "prompt-version-description", "templateType": "CHAT", "templateFormat": "MUSTACHE", - "template": {"messages": [{"role": "user", "content": "hello world"}]}, + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, "modelProvider": "openai", "modelName": "o1-mini", @@ -163,7 +163,7 @@ async def test_create_prompt_fails_on_name_conflict( "description": "prompt-version-description", "templateType": "CHAT", "templateFormat": "MUSTACHE", - "template": {"messages": [{"role": "user", "content": "hello world"}]}, + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, "modelProvider": "openai", "modelName": "o1-mini", @@ -183,33 +183,6 @@ async def test_create_prompt_fails_on_name_conflict( @pytest.mark.parametrize( "variables,expected_error", [ - pytest.param( - { - "input": { - "name": "another-prompt-name", - "description": "prompt-description", - "promptVersion": { - "description": "prompt-version-description", - "templateType": "CHAT", - "templateFormat": "MUSTACHE", - "template": { - "messages": [ - { - "role": "user", - "content": "hello world", - "extra_key": "test_value", - } - ] - }, - "invocationParameters": {"temperature": 0.4}, - "modelProvider": "openai", - "modelName": "o1-mini", - }, - } - }, - "extra_key", - id="extra-key-in-message", - ), pytest.param( { "input": { @@ -219,7 +192,7 @@ async def test_create_prompt_fails_on_name_conflict( "description": "prompt-version-description", "templateType": "CHAT", "templateFormat": "MUSTACHE", - "template": {"messages": [{"role": "user", "content": "hello world"}]}, + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, "modelProvider": "openai", "modelName": "o1-mini", @@ -241,7 +214,7 @@ async def test_create_prompt_fails_on_name_conflict( "description": "prompt-version-description", "templateType": "CHAT", "templateFormat": "MUSTACHE", - "template": {"messages": [{"role": "user", "content": "hello world"}]}, + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, "modelProvider": "openai", "modelName": "o1-mini",