From 698322d3aeb937a5cd9e75f61c9b50b167e290c6 Mon Sep 17 00:00:00 2001
From: Alexander Song
Date: Thu, 2 Jan 2025 22:43:24 -0800
Subject: [PATCH] improve tests

---
 app/schema.graphql                          |   6 +-
 .../api/input_types/PromptVersionInput.py   |   2 +-
 .../api/mutations/test_prompt_mutations.py  | 193 ++++++++++++++----
 3 files changed, 157 insertions(+), 44 deletions(-)

diff --git a/app/schema.graphql b/app/schema.graphql
index bdfc0ef936..4d0fec83bc 100644
--- a/app/schema.graphql
+++ b/app/schema.graphql
@@ -1484,15 +1484,15 @@ type PromptVersionEdge {
 }
 
 input PromptVersionInput {
-  invocationParameters: JSON! = {}
   description: String = null
   templateType: PromptTemplateType!
   templateFormat: PromptTemplateFormat!
   template: JSON!
-  tools: [ToolDefinitionInput!]!
+  invocationParameters: JSON! = {}
+  tools: [ToolDefinitionInput!]! = []
   outputSchema: JSON = null
-  modelName: String!
   modelProvider: String!
+  modelName: String!
 }
 
 type PromptVersionTag implements Node {
diff --git a/src/phoenix/server/api/input_types/PromptVersionInput.py b/src/phoenix/server/api/input_types/PromptVersionInput.py
index 392f7909b5..0e0bb4cf66 100644
--- a/src/phoenix/server/api/input_types/PromptVersionInput.py
+++ b/src/phoenix/server/api/input_types/PromptVersionInput.py
@@ -19,7 +19,7 @@ class ToolDefinitionInput:
 
 @strawberry.input
 class JSONSchemaInput:
-    schema: JSON
+    definition: JSON
 
 
 @strawberry.input
diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py
index 0ca16fd8ed..0dc694c47f 100644
--- a/tests/unit/server/api/mutations/test_prompt_mutations.py
+++ b/tests/unit/server/api/mutations/test_prompt_mutations.py
@@ -1,4 +1,6 @@
-from typing import Any
+from typing import Any, Optional
+
+import pytest
 
 from phoenix.server.types import DbSessionFactory
 from tests.unit.graphql import AsyncGraphQLClient
@@ -44,24 +46,81 @@ class TestPromptMutations:
     }
     """
 
+    @pytest.mark.parametrize(
+        "variables,expected_tools,expected_output_schema",
+        [
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                        },
+                    }
+                },
+                [],
+                None,
+                id="basic-input",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "tools": [{"definition": {"foo": "bar"}}],
+                        },
+                    }
+                },
+                [{"definition": {"foo": "bar"}}],
+                None,
+                id="with-tools",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "outputSchema": {"definition": {"foo": "bar"}},
+                        },
+                    }
+                },
+                [],
+                {"definition": {"foo": "bar"}},
+                id="with-output-schema",
+            ),
+        ],
+    )
     async def test_create_prompt_succeeds_with_valid_input(
-        self, db: DbSessionFactory, gql_client: AsyncGraphQLClient
+        self,
+        db: DbSessionFactory,
+        gql_client: AsyncGraphQLClient,
+        variables: dict[str, Any],
+        expected_tools: list[dict[str, Any]],
+        expected_output_schema: Optional[dict[str, Any]],
     ) -> None:
-        variables: dict[str, Any] = {
-            "input": {
-                "name": "prompt-name",
-                "description": "prompt-description",
-                "promptVersion": {
-                    "description": "prompt-version-description",
-                    "templateType": "CHAT",
-                    "templateFormat": "MUSTACHE",
-                    "template": {"messages": [{"role": "user", "content": "hello world"}]},
-                    "invocationParameters": {"temperature": 0.4},
-                    "modelProvider": "openai",
-                    "modelName": "o1-mini",
-                },
-            }
-        }
         result = await gql_client.execute(self.MUTATION, variables)
         assert not result.errors
         assert result.data is not None
@@ -79,8 +138,8 @@ async def test_create_prompt_succeeds_with_valid_input(
         assert prompt_version.pop("modelProvider") == "openai"
         assert prompt_version.pop("modelName") == "o1-mini"
         assert prompt_version.pop("invocationParameters") == {"temperature": 0.4}
-        assert prompt_version.pop("tools") == []
-        assert prompt_version.pop("outputSchema") is None
+        assert prompt_version.pop("tools") == expected_tools
+        assert prompt_version.pop("outputSchema") == expected_output_schema
         assert isinstance(prompt_version.pop("id"), str)
 
         # Verify messages
@@ -121,28 +180,82 @@ async def test_create_prompt_fails_on_name_conflict(
         assert result.errors[0].message == "A prompt named 'prompt-name' already exists"
         assert result.data is None
 
-    async def test_create_prompt_with_invalid_input(self, gql_client: AsyncGraphQLClient) -> None:
-        """Test that creating a prompt with invalid input raises an error."""
-        variables = {
-            "input": {
-                "name": "another-prompt-name",
-                "description": "prompt-description",
-                "promptVersion": {
-                    "description": "prompt-version-description",
-                    "templateType": "CHAT",
-                    "templateFormat": "MUSTACHE",
-                    "template": {
-                        "messages": [
-                            {"role": "user", "content": "hello world", "extra_key": "test_value"}
-                        ]
-                    },
-                    "invocationParameters": {"temperature": 0.4},
-                    "modelProvider": "openai",
-                    "modelName": "o1-mini",
-                },
-            }
-        }
+    @pytest.mark.parametrize(
+        "variables,expected_error",
+        [
+            pytest.param(
+                {
+                    "input": {
+                        "name": "another-prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {
+                                "messages": [
+                                    {
+                                        "role": "user",
+                                        "content": "hello world",
+                                        "extra_key": "test_value",
+                                    }
+                                ]
+                            },
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                        },
+                    }
+                },
+                "extra_key",
+                id="extra-key-in-message",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
+                            "invocationParameters": {"temperature": 0.4},
+                            "modelProvider": "openai",
+                            "modelName": "o1-mini",
+                            "tools": [{"definition": ["foo", "bar"]}],  # a list where a dict is required
+                        },
+                    }
+                },
+                "Input should be a valid dictionary",
+                id="tools-not-a-dict",
+            ),
+            pytest.param(
+                {
+                    "input": {
+                        "name": "prompt-name",
+                        "description": "prompt-description",
+                        "promptVersion": {
+                            "description": "prompt-version-description",
+                            "templateType": "CHAT",
+                            "templateFormat": "MUSTACHE",
+                            "template": {"messages": [{"role": "user", "content": "hello world"}]},
{"messages": [{"role": "user", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "outputSchema": {"unknown_key": {"hello": "world"}}, + }, + } + }, + "Field required", + id="invalid-output-schema", + ), + ], + ) + async def test_create_prompt_with_invalid_input( + self, gql_client: AsyncGraphQLClient, variables: dict[str, Any], expected_error: str + ) -> None: result = await gql_client.execute(self.MUTATION, variables) assert len(result.errors) == 1 - assert "extra_key" in result.errors[0].message + assert expected_error in result.errors[0].message assert result.data is None