Skip to content

Commit

Permalink
add types for create prompt mutation input
Browse files — browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Jan 3, 2025
1 parent 2096edc commit d105dc4
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 40 deletions.
11 changes: 10 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,10 @@ type PromptChatTemplate {
messages: [PromptTemplateMessage!]!
}

input PromptChatTemplateInput {
messages: [TextPromptMessageInput!]!
}

"""A connection to a list of items."""
type PromptConnection {
"""Pagination data for this connection"""
Expand Down Expand Up @@ -1491,7 +1495,7 @@ input PromptVersionInput {
description: String = null
templateType: PromptTemplateType!
templateFormat: PromptTemplateFormat!
template: JSON!
template: PromptChatTemplateInput!
invocationParameters: JSON! = {}
tools: [ToolDefinitionInput!]! = []
outputSchema: JSONSchemaInput = null
Expand Down Expand Up @@ -1847,6 +1851,11 @@ type TextPromptMessage {
content: String!
}

input TextPromptMessageInput {
role: PromptMessageRole!
content: String!
}

input TimeRange {
"""The start of the time range"""
start: DateTime!
Expand Down
21 changes: 16 additions & 5 deletions src/phoenix/server/api/input_types/PromptVersionInput.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from strawberry.scalars import JSON

from phoenix.server.api.helpers.prompts.models import (
PromptChatTemplateV1,
PromptTemplateFormat,
PromptTemplateType,
)
from phoenix.server.api.helpers.prompts.models import (
PromptToolDefinition as ToolDefinitionModel,
PromptToolDefinition,
TextPromptMessage,
)


@strawberry.experimental.pydantic.input(ToolDefinitionModel)
@strawberry.experimental.pydantic.input(PromptToolDefinition)
class ToolDefinitionInput:
definition: JSON

Expand All @@ -22,12 +22,23 @@ class JSONSchemaInput:
definition: JSON


@strawberry.experimental.pydantic.input(TextPromptMessage)
class TextPromptMessageInput:
role: strawberry.auto
content: strawberry.auto


@strawberry.experimental.pydantic.input(PromptChatTemplateV1)
class PromptChatTemplateInput:
messages: list[TextPromptMessageInput]


@strawberry.input
class PromptVersionInput:
description: Optional[str] = None
template_type: PromptTemplateType
template_format: PromptTemplateFormat
template: JSON
template: PromptChatTemplateInput
invocation_parameters: JSON = strawberry.field(default_factory=dict)
tools: list[ToolDefinitionInput] = strawberry.field(default_factory=list)
output_schema: Optional[JSONSchemaInput] = None
Expand Down
4 changes: 3 additions & 1 deletion src/phoenix/server/api/mutations/prompt_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ async def create_prompt(self, info: Info[Context, None], input: CreatePromptInput
if input.prompt_version.output_schema is not None
else None
)
template = PromptChatTemplateV1.model_validate(input.prompt_version.template).dict()
template = PromptChatTemplateV1.model_validate(
strawberry.asdict(input.prompt_version.template)
).dict()
except ValidationError as error:
raise BadRequest(str(error))
async with info.context.db() as session:
Expand Down
39 changes: 6 additions & 33 deletions tests/unit/server/api/mutations/test_prompt_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TestPromptMutations:
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {"messages": [{"role": "user", "content": "hello world"}]},
"template": {"messages": [{"role": "USER", "content": "hello world"}]},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
Expand All @@ -78,7 +78,7 @@ class TestPromptMutations:
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {"messages": [{"role": "user", "content": "hello world"}]},
"template": {"messages": [{"role": "USER", "content": "hello world"}]},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
Expand All @@ -99,7 +99,7 @@ class TestPromptMutations:
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {"messages": [{"role": "user", "content": "hello world"}]},
"template": {"messages": [{"role": "USER", "content": "hello world"}]},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
Expand Down Expand Up @@ -163,7 +163,7 @@ async def test_create_prompt_fails_on_name_conflict(
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {"messages": [{"role": "user", "content": "hello world"}]},
"template": {"messages": [{"role": "USER", "content": "hello world"}]},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
Expand All @@ -183,33 +183,6 @@ async def test_create_prompt_fails_on_name_conflict(
@pytest.mark.parametrize(
"variables,expected_error",
[
pytest.param(
{
"input": {
"name": "another-prompt-name",
"description": "prompt-description",
"promptVersion": {
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {
"messages": [
{
"role": "user",
"content": "hello world",
"extra_key": "test_value",
}
]
},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
},
}
},
"extra_key",
id="extra-key-in-message",
),
pytest.param(
{
"input": {
Expand All @@ -219,7 +192,7 @@ async def test_create_prompt_fails_on_name_conflict(
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {"messages": [{"role": "user", "content": "hello world"}]},
"template": {"messages": [{"role": "USER", "content": "hello world"}]},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
Expand All @@ -241,7 +214,7 @@ async def test_create_prompt_fails_on_name_conflict(
"description": "prompt-version-description",
"templateType": "CHAT",
"templateFormat": "MUSTACHE",
"template": {"messages": [{"role": "user", "content": "hello world"}]},
"template": {"messages": [{"role": "USER", "content": "hello world"}]},
"invocationParameters": {"temperature": 0.4},
"modelProvider": "openai",
"modelName": "o1-mini",
Expand Down

0 comments on commit d105dc4

Please sign in to comment.