
Commit 698322d
improve tests
axiomofjoy committed Jan 3, 2025
1 parent 6021633 commit 698322d
Showing 3 changed files with 157 additions and 44 deletions.
6 changes: 3 additions & 3 deletions app/schema.graphql
@@ -1484,15 +1484,15 @@ type PromptVersionEdge {
}

input PromptVersionInput {
-  invocationParameters: JSON! = {}
  description: String = null
  templateType: PromptTemplateType!
  templateFormat: PromptTemplateFormat!
  template: JSON!
-  tools: [ToolDefinitionInput!]!
+  invocationParameters: JSON! = {}
+  tools: [ToolDefinitionInput!]! = []
  outputSchema: JSON = null
-  modelName: String!
  modelProvider: String!
+  modelName: String!
}

type PromptVersionTag implements Node {
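Beyond the field reordering, the substantive schema change is that tools now carries a default of [] (invocationParameters already defaulted to {}), so every field with a default can be omitted by the client. A minimal sketch of the GraphQL variables such a request could send — the createPrompt mutation document is not part of this diff, so the surrounding call and the outer input's required fields are assumed:

# Hypothetical minimal payload: only promptVersion fields without
# schema defaults are supplied.
variables = {
    "input": {
        "name": "minimal-prompt",
        "promptVersion": {
            "templateType": "CHAT",
            "templateFormat": "MUSTACHE",
            "template": {"messages": [{"role": "user", "content": "hello world"}]},
            "modelProvider": "openai",
            "modelName": "o1-mini",
            # invocationParameters defaults to {}, tools to [],
            # outputSchema and description to null
        },
    }
}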
2 changes: 1 addition & 1 deletion src/phoenix/server/api/input_types/PromptVersionInput.py
@@ -19,7 +19,7 @@ class ToolDefinitionInput:

@strawberry.input
class JSONSchemaInput:
-    schema: JSON
+    definition: JSON


@strawberry.input
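Strawberry derives GraphQL field names from Python attribute names, so this one-line rename is client-visible: a JSON schema payload is now passed under definition rather than schema. A before/after sketch, with illustrative values only:

# Before this commit:
output_schema = {"schema": {"type": "object", "properties": {}}}

# After this commit (the shape the updated tests below exercise):
output_schema = {"definition": {"type": "object", "properties": {}}}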
193 changes: 153 additions & 40 deletions tests/unit/server/api/mutations/test_prompt_mutations.py
@@ -1,4 +1,6 @@
-from typing import Any
+from typing import Any, Optional

+import pytest
+
from phoenix.server.types import DbSessionFactory
from tests.unit.graphql import AsyncGraphQLClient
@@ -44,24 +46,81 @@ class TestPromptMutations:
      }
    """

+    @pytest.mark.parametrize(
+        "variables,expected_tools,expected_output_schema",
+        [
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                        },
+                    }
+                },
+                [],
+                None,
+                id="basic-input",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "tools": [{"definition": {"foo": "bar"}}],
+                        },
+                    }
+                },
+                [{"definition": {"foo": "bar"}}],
+                None,
+                id="with-tools",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "outputSchema": {"definition": {"foo": "bar"}},
+                        },
+                    }
+                },
+                [],
+                {"definition": {"foo": "bar"}},
+                id="with-output-schema",
+            ),
+        ],
+    )
    async def test_create_prompt_succeeds_with_valid_input(
-        self, db: DbSessionFactory, gql_client: AsyncGraphQLClient
+        self,
+        db: DbSessionFactory,
+        gql_client: AsyncGraphQLClient,
+        variables: dict[str, Any],
+        expected_tools: list[dict[str, Any]],
+        expected_output_schema: Optional[dict[str, Any]],
    ) -> None:
-        variables: dict[str, Any] = {
-            "input": {
-                "name": "prompt-name",
-                "description": "prompt-description",
-                "promptVersion": {
-                    "description": "prompt-version-description",
-                    "templateType": "CHAT",
-                    "templateFormat": "MUSTACHE",
-                    "template": {"messages": [{"role": "user", "content": "hello world"}]},
-                    "invocationParameters": {"temperature": 0.4},
-                    "modelProvider": "openai",
-                    "modelName": "o1-mini",
-                },
-            }
-        }
        result = await gql_client.execute(self.MUTATION, variables)
        assert not result.errors
        assert result.data is not None
@@ -79,8 +138,8 @@ async def test_create_prompt_succeeds_with_valid_input(
        assert prompt_version.pop("modelProvider") == "openai"
        assert prompt_version.pop("modelName") == "o1-mini"
        assert prompt_version.pop("invocationParameters") == {"temperature": 0.4}
-        assert prompt_version.pop("tools") == []
-        assert prompt_version.pop("outputSchema") is None
+        assert prompt_version.pop("tools") == expected_tools
+        assert prompt_version.pop("outputSchema") == expected_output_schema
        assert isinstance(prompt_version.pop("id"), str)

        # Verify messages
@@ -121,28 +180,82 @@ async def test_create_prompt_fails_on_name_conflict(
        assert result.errors[0].message == "A prompt named 'prompt-name' already exists"
        assert result.data is None

-    async def test_create_prompt_with_invalid_input(self, gql_client: AsyncGraphQLClient) -> None:
-        """Test that creating a prompt with invalid input raises an error."""
-        variables = {
-            "input": {
-                "name": "another-prompt-name",
-                "description": "prompt-description",
-                "promptVersion": {
-                    "description": "prompt-version-description",
-                    "templateType": "CHAT",
-                    "templateFormat": "MUSTACHE",
-                    "template": {
-                        "messages": [
-                            {"role": "user", "content": "hello world", "extra_key": "test_value"}
-                        ]
-                    },
-                    "invocationParameters": {"temperature": 0.4},
-                    "modelProvider": "openai",
-                    "modelName": "o1-mini",
-                },
-            }
-        }
+    @pytest.mark.parametrize(
+        "variables,expected_error",
+        [
+            pytest.param(
+                {
+                    "input": {
+                        "name": "another-prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {
+                                "messages": [
+                                    {
+                                        "role": "user",
+                                        "content": "hello world",
+                                        "extra_key": "test_value",
+                                    }
+                                ]
+                            },
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                        },
+                    }
+                },
+                "extra_key",
+                id="extra-key-in-message",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "tools": [{"definition": ["foo", "bar"]}],  # definition is a list, not a dict
+                        },
+                    }
+                },
+                "Input should be a valid dictionary",
+                id="tools-not-a-dict",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "outputSchema": {"unknown_key": {"hello": "world"}},
+                        },
+                    }
+                },
+                "Field required",
+                id="invalid-output-schema",
+            ),
+        ],
+    )
+    async def test_create_prompt_with_invalid_input(
+        self, gql_client: AsyncGraphQLClient, variables: dict[str, Any], expected_error: str
+    ) -> None:
        result = await gql_client.execute(self.MUTATION, variables)
        assert len(result.errors) == 1
-        assert "extra_key" in result.errors[0].message
+        assert expected_error in result.errors[0].message
        assert result.data is None
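The reworked tests can be run in isolation by pytest node id — a small convenience sketch, assuming pytest is installed and the snippet is executed from the repository root:

import pytest

# Runs all parametrized cases of the invalid-input test; append an id from
# the diff above, e.g. [tools-not-a-dict], to select a single case.
pytest.main([
    "tests/unit/server/api/mutations/test_prompt_mutations.py"
    "::TestPromptMutations::test_create_prompt_with_invalid_input",
])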
