From 10c11985d57655dd8d48b58feb9c6d2d115afc21 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 Jan 2025 23:48:51 -0800 Subject: [PATCH 01/15] pass valid inputs --- .../server/api/helpers/prompts/models.py | 66 ++- tests/unit/server/api/helpers/test_models.py | 501 ++++++++++++++++++ 2 files changed, 565 insertions(+), 2 deletions(-) create mode 100644 tests/unit/server/api/helpers/test_models.py diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index ebaaddac02..af85e18b77 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field from typing_extensions import TypeAlias @@ -54,7 +54,7 @@ class PromptStringTemplateV1(PromptModel): PromptTemplate: TypeAlias = Union[PromptChatTemplateV1, PromptStringTemplateV1] -class PromptJSONSchema(BaseModel): +class PromptJSONSchema(PromptModel): """A JSON schema definition used to guide an LLM's output""" definition: dict[str, Any] @@ -67,3 +67,65 @@ class PromptToolDefinition(PromptModel): class PromptToolsV1(PromptModel): version: Literal["tools-v1"] = "tools-v1" tool_definitions: list[PromptToolDefinition] = Field(..., min_length=1) + + +# Tool models +JSONSchemaDataType = Literal["string", "number", "boolean", "object", "array", "null", "integer"] + + +class Undefined: + """ + A singleton class that represents an unset or undefined value. Needed since Pydantic + can't natively distinguish between an undefined value and a value that is set to + None. + """ + + def __new__(cls) -> Any: + if not hasattr(cls, "_instance"): + cls._instance = super().__new__(cls) + return cls._instance + + +UNDEFINED: Any = Undefined() + + +JSONSchemaPropertyType = Union["JSONSchema", "JSONSchemaProperty", "JSONSchemaPropertyUnion"] + + +class JSONSchemaProperty(PromptModel): + type: JSONSchemaDataType + description: str = UNDEFINED + items: JSONSchemaPropertyType = UNDEFINED + enum: list[str] = UNDEFINED + + +class JSONSchemaPropertyUnion(PromptModel): + any_of: list[Union["JSONSchema", "JSONSchemaProperty"]] = Field(UNDEFINED, alias="anyOf") + + +class JSONSchema(PromptModel): + type: JSONSchemaDataType + description: str = UNDEFINED + properties: dict[str, JSONSchemaPropertyType] = UNDEFINED + required: list[str] = UNDEFINED + additional_properties: bool = Field(UNDEFINED, alias="additionalProperties") + + +class OpenAIFunctionDefinition(PromptModel): + """ + Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/shared_params/function_definition.py#L13 + """ + + name: str + description: str = UNDEFINED + parameters: JSONSchema = UNDEFINED + strict: Optional[bool] = UNDEFINED + + +class OpenAIToolDefinition(PromptModel): + """ + Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/chat/chat_completion_tool_param.py#L12 + """ + + function: OpenAIFunctionDefinition + type: Literal["function"] diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py new file mode 100644 index 0000000000..e29ee2b028 --- /dev/null +++ b/tests/unit/server/api/helpers/test_models.py @@ -0,0 +1,501 @@ +from typing import Any + +import pytest + +from phoenix.server.api.helpers.prompts.models import OpenAIToolDefinition + + 
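+# NOTE: the fixtures below appear to mirror OpenAI's published
+# function-calling examples (get_weather, get_delivery_date,
+# generate_recipe, and so on). Each parametrized case is a complete
+# tool payload of the shape
+#
+#     {"type": "function", "function": {"name": ..., "parameters": {...}}}
+#
+# and the test at the bottom of this module simply asserts that
+#
+#     OpenAIToolDefinition.model_validate(tool_definition)
+#
+# does not raise for any of them.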
+@pytest.mark.parametrize( + "tool_definition", + [ + pytest.param( + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + }, + id="get-weather-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_delivery_date", + "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'", # noqa: E501 + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The customer's order ID.", + } + }, + "required": ["order_id"], + "additionalProperties": False, + }, + }, + }, + id="get-delivery-date-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "generate_recipe", + "description": "Generate a recipe based on the user's input", + "parameters": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the recipe.", + }, + "ingredients": { + "type": "array", + "items": {"type": "string"}, + "description": "List of ingredients required for the recipe.", + }, + "instructions": { + "type": "array", + "items": {"type": "string"}, + "description": "Step-by-step instructions for the recipe.", + }, + }, + "required": ["title", "ingredients", "instructions"], + "additionalProperties": False, + }, + }, + }, + id="generate-recipe-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_product_recommendations", + "description": "Searches for products matching certain criteria in the database", # noqa: E501 + "parameters": { + "type": "object", + "properties": { + "categories": { + "description": "categories that could be a match", + "type": "array", + "items": { + "type": "string", + "enum": [ + "coats & jackets", + "accessories", + "tops", + "jeans & trousers", + "skirts & dresses", + "shoes", + ], + }, + }, + "colors": { + "description": "colors that could be a match, empty array if N/A", + "type": "array", + "items": { + "type": "string", + "enum": [ + "black", + "white", + "brown", + "red", + "blue", + "green", + "orange", + "yellow", + "pink", + "gold", + "silver", + ], + }, + }, + "keywords": { + "description": "keywords that should be present in the item title or description", # noqa: E501 + "type": "array", + "items": {"type": "string"}, + }, + "price_range": { + "type": "object", + "properties": { + "min": {"type": "number"}, + "max": {"type": "number"}, + }, + "required": ["min", "max"], + "additionalProperties": False, + }, + "limit": { + "type": "integer", + "description": "The maximum number of products to return, use 5 by default if nothing is specified by the user", # noqa: E501 + }, + }, + "required": ["categories", "colors", "keywords", "price_range", "limit"], + "additionalProperties": False, + }, + }, + }, + id="get-product-recommendations-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_product_details", + "description": "Fetches more details about a product", + "parameters": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "The ID of the product to fetch details for", + } + }, + "required": ["product_id"], + "additionalProperties": False, + }, + }, + }, + id="get-product-details-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "add_to_cart", + "description": "Add items 
to cart when the user has confirmed their interest.", + "parameters": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "ID of the product to add to the cart", + }, + "quantity": { + "type": "integer", + "description": "Quantity of the product to add to the cart", # noqa: E501 + }, + }, + "required": ["product_id", "quantity"], + "additionalProperties": False, + }, + } + }, + "required": ["items"], + "additionalProperties": False, + }, + }, + }, + id="add-to-cart-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_order_details", + "description": "Fetches details about a specific order", + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The ID of the order to fetch details for", + } + }, + "required": ["order_id"], + "additionalProperties": False, + }, + }, + }, + id="get-order-details-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_user_orders", + "description": "Fetches the last orders for a given user", + "parameters": { + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "The ID of the user to fetch orders for", + }, + "limit": { + "type": "integer", + "description": "The maximum number of orders to return, use 5 by default and increase the number if the relevant order is not found.", # noqa: E501 + }, + }, + "required": ["user_id", "limit"], + "additionalProperties": False, + }, + }, + }, + id="get-user-orders-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "search_faq", + "description": "Searches the FAQ for an answer to the user's question", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The question to search the FAQ for", + } + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + }, + id="search-faq-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "process_return", + "description": "Processes a return and creates a return label", + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The ID of the order to process a return for", + }, + "items": { + "type": "array", + "description": "The items to return", + "items": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "The ID of the product to return", + }, + "quantity": { + "type": "integer", + "description": "The quantity of the product to return", + }, + }, + "required": ["product_id", "quantity"], + "additionalProperties": False, + }, + }, + }, + "required": ["order_id", "items"], + "additionalProperties": False, + }, + }, + }, + id="process-return-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_return_status", + "description": "Finds the status of a return", + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The ID of the order to fetch the return status for", + } + }, + "required": ["order_id"], + "additionalProperties": False, + }, + }, + }, + id="get-return-status-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_recommendations", + "description": "Fetches recommendations based on the user's preferences", + "parameters": { + "type": "object", + 
"properties": { + "type": { + "type": "string", + "description": "The type of place to search recommendations for", + "enum": ["restaurant", "hotel"], + }, + "keywords": { + "type": "array", + "description": "Keywords that should be present in the recommendations", # noqa: E501 + "items": {"type": "string"}, + }, + "location": { + "type": "string", + "description": "The location to search recommendations for", + }, + }, + "required": ["type", "keywords", "location"], + "additionalProperties": False, + }, + }, + }, + id="get-recommendations-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "show_on_map", + "description": "Places pins on the map for relevant locations", + "parameters": { + "type": "object", + "properties": { + "pins": { + "type": "array", + "description": "The pins to place on the map", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the place", + }, + "coordinates": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + }, + "required": ["name", "coordinates"], + "additionalProperties": False, + }, + }, + }, + "required": ["pins"], + "additionalProperties": False, + }, + }, + }, + id="show-on-map-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "fetch_availability", + "description": "Fetches the availability for a given place", + "parameters": { + "type": "object", + "properties": { + "place_id": { + "type": "string", + "description": "The ID of the place to fetch availability for", + } + }, + }, + }, + }, + id="fetch-availability-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "create_booking", + "description": "Creates a booking on the user's behalf", + "parameters": { + "type": "object", + "properties": { + "place_id": { + "type": "string", + "description": "The ID of the place to create a booking for", + }, + "booking_details": { + "anyOf": [ + { + "type": "object", + "description": "Restaurant booking with specific date and time", # noqa: E501 + "properties": { + "date": { + "type": "string", + "description": "The date of the booking, in format YYYY-MM-DD", # noqa: E501 + }, + "time": { + "type": "string", + "description": "The time of the booking, in format HH:MM", # noqa: E501 + }, + }, + "required": ["date", "time"], + }, + { + "type": "object", + "description": "Hotel booking with specific check-in and check-out dates", # noqa: E501 + "properties": { + "check_in": { + "type": "string", + "description": "The check-in date of the booking, in format YYYY-MM-DD", # noqa: E501 + }, + "check_out": { + "type": "string", + "description": "The check-out date of the booking, in format YYYY-MM-DD", # noqa: E501 + }, + }, + "required": ["check_in", "check_out"], + }, + ], + }, + }, + "required": ["place_id", "booking_details"], + "additionalProperties": False, + }, + }, + }, + id="create-booking-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "pick_tshirt_size", + "description": "Call this if the user specifies which size t-shirt they want", + "parameters": { + "type": "object", + "properties": { + "size": { + "type": "string", + "enum": ["s", "m", "l"], + "description": "The size of the t-shirt that the user would like to order", # noqa: E501 + } + }, + "required": ["size"], + "additionalProperties": False, + }, + }, + }, + 
id="pick-tshirt-size-function", + ), + ], +) +def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[str, Any]) -> None: + OpenAIToolDefinition.model_validate(tool_definition) From 8ee03e2c5ebd2bfcb68151e8effb6d4a484d6eef Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 00:21:46 -0800 Subject: [PATCH 02/15] refactor --- .../server/api/helpers/prompts/models.py | 85 ++++++++++++++----- 1 file changed, 62 insertions(+), 23 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index af85e18b77..ce31bd9c23 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -7,6 +7,22 @@ JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] +class Undefined: + """ + A singleton class that represents an unset or undefined value. Needed since Pydantic + can't natively distinguish between an undefined value and a value that is set to + None. + """ + + def __new__(cls) -> Any: + if not hasattr(cls, "_instance"): + cls._instance = super().__new__(cls) + return cls._instance + + +UNDEFINED: Any = Undefined() + + class PromptTemplateType(str, Enum): STRING = "STR" CHAT = "CHAT" @@ -69,48 +85,71 @@ class PromptToolsV1(PromptModel): tool_definitions: list[PromptToolDefinition] = Field(..., min_length=1) -# Tool models -JSONSchemaDataType = Literal["string", "number", "boolean", "object", "array", "null", "integer"] - +# JSON schema models +JSONSchemaPrimitiveProperty: TypeAlias = Union[ + "JSONSchemaNumberProperty", + "JSONSchemaBooleanProperty", + "JSONSchemaNullProperty", + "JSONSchemaIntegerProperty", + "JSONSchemaStringProperty", +] +JSONSchemaContainerProperty: TypeAlias = Union[ + "JSONSchemaArrayProperty", + "JSONSchemaObjectProperty", +] +JSONSchemaPropertyType: TypeAlias = Union[ + "JSONSchemaPrimitiveProperty", + "JSONSchemaContainerProperty", + "JSONSchemaUnionProperty", +] + + +class JSONSchemaNumberProperty(PromptModel): + type: Literal["number"] + description: str = UNDEFINED -class Undefined: - """ - A singleton class that represents an unset or undefined value. Needed since Pydantic - can't natively distinguish between an undefined value and a value that is set to - None. 
- """ - def __new__(cls) -> Any: - if not hasattr(cls, "_instance"): - cls._instance = super().__new__(cls) - return cls._instance +class JSONSchemaBooleanProperty(PromptModel): + type: Literal["boolean"] + description: str = UNDEFINED -UNDEFINED: Any = Undefined() +class JSONSchemaNullProperty(PromptModel): + type: Literal["null"] + description: str = UNDEFINED -JSONSchemaPropertyType = Union["JSONSchema", "JSONSchemaProperty", "JSONSchemaPropertyUnion"] +class JSONSchemaIntegerProperty(PromptModel): + type: Literal["integer"] + description: str = UNDEFINED -class JSONSchemaProperty(PromptModel): - type: JSONSchemaDataType +class JSONSchemaStringProperty(PromptModel): + type: Literal["string"] description: str = UNDEFINED - items: JSONSchemaPropertyType = UNDEFINED enum: list[str] = UNDEFINED -class JSONSchemaPropertyUnion(PromptModel): - any_of: list[Union["JSONSchema", "JSONSchemaProperty"]] = Field(UNDEFINED, alias="anyOf") +class JSONSchemaArrayProperty(PromptModel): + type: Literal["array"] + description: str = UNDEFINED + items: JSONSchemaPropertyType = UNDEFINED -class JSONSchema(PromptModel): - type: JSONSchemaDataType +class JSONSchemaObjectProperty(PromptModel): + type: Literal["object"] description: str = UNDEFINED properties: dict[str, JSONSchemaPropertyType] = UNDEFINED required: list[str] = UNDEFINED additional_properties: bool = Field(UNDEFINED, alias="additionalProperties") +class JSONSchemaUnionProperty(PromptModel): + any_of: list[Union["JSONSchemaPrimitiveProperty", "JSONSchemaContainerProperty"]] = Field( + ..., alias="anyOf" + ) + + class OpenAIFunctionDefinition(PromptModel): """ Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/shared_params/function_definition.py#L13 @@ -118,7 +157,7 @@ class OpenAIFunctionDefinition(PromptModel): name: str description: str = UNDEFINED - parameters: JSONSchema = UNDEFINED + parameters: JSONSchemaObjectProperty = UNDEFINED strict: Optional[bool] = UNDEFINED From b60e025212b7da01a9a7e52a00784ad71da57b09 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 00:51:57 -0800 Subject: [PATCH 03/15] use generics --- .../server/api/helpers/prompts/models.py | 59 +++++++++---------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index ce31bd9c23..17707c0bfb 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -1,5 +1,6 @@ +from abc import ABC from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Generic, Literal, Optional, TypeVar, Union from pydantic import BaseModel, ConfigDict, Field from typing_extensions import TypeAlias @@ -85,7 +86,7 @@ class PromptToolsV1(PromptModel): tool_definitions: list[PromptToolDefinition] = Field(..., min_length=1) -# JSON schema models +# JSON schema JSONSchemaPrimitiveProperty: TypeAlias = Union[ "JSONSchemaNumberProperty", "JSONSchemaBooleanProperty", @@ -97,59 +98,57 @@ class PromptToolsV1(PromptModel): "JSONSchemaArrayProperty", "JSONSchemaObjectProperty", ] -JSONSchemaPropertyType: TypeAlias = Union[ +JSONSchemaProperty: TypeAlias = Union[ "JSONSchemaPrimitiveProperty", "JSONSchemaContainerProperty", - "JSONSchemaUnionProperty", ] +JSONSchemaDataType = TypeVar( + "JSONSchemaDataType", + bound=Literal["number", "boolean", "null", "integer", "string", "array", "object"], +) -class 
JSONSchemaNumberProperty(PromptModel): - type: Literal["number"] +class BaseJSONSchemaProperty(ABC, Generic[JSONSchemaDataType], PromptModel): + type: JSONSchemaDataType description: str = UNDEFINED -class JSONSchemaBooleanProperty(PromptModel): - type: Literal["boolean"] - description: str = UNDEFINED +class JSONSchemaNumberProperty(BaseJSONSchemaProperty[Literal["number"]]): + pass -class JSONSchemaNullProperty(PromptModel): - type: Literal["null"] - description: str = UNDEFINED +class JSONSchemaBooleanProperty(BaseJSONSchemaProperty[Literal["boolean"]]): + pass -class JSONSchemaIntegerProperty(PromptModel): - type: Literal["integer"] - description: str = UNDEFINED +class JSONSchemaNullProperty(BaseJSONSchemaProperty[Literal["null"]]): + pass -class JSONSchemaStringProperty(PromptModel): - type: Literal["string"] - description: str = UNDEFINED +class JSONSchemaIntegerProperty(BaseJSONSchemaProperty[Literal["integer"]]): + pass + + +class JSONSchemaStringProperty(BaseJSONSchemaProperty[Literal["string"]]): enum: list[str] = UNDEFINED -class JSONSchemaArrayProperty(PromptModel): - type: Literal["array"] - description: str = UNDEFINED - items: JSONSchemaPropertyType = UNDEFINED +class JSONSchemaArrayProperty(BaseJSONSchemaProperty[Literal["array"]]): + items: Union[JSONSchemaProperty, "JSONSchemaAnyOf"] = UNDEFINED -class JSONSchemaObjectProperty(PromptModel): - type: Literal["object"] - description: str = UNDEFINED - properties: dict[str, JSONSchemaPropertyType] = UNDEFINED +class JSONSchemaObjectProperty(BaseJSONSchemaProperty[Literal["object"]]): + properties: dict[str, Union[JSONSchemaProperty, "JSONSchemaAnyOf"]] = UNDEFINED required: list[str] = UNDEFINED additional_properties: bool = Field(UNDEFINED, alias="additionalProperties") -class JSONSchemaUnionProperty(PromptModel): - any_of: list[Union["JSONSchemaPrimitiveProperty", "JSONSchemaContainerProperty"]] = Field( - ..., alias="anyOf" - ) +class JSONSchemaAnyOf(PromptModel): + description: str = UNDEFINED + any_of: list[JSONSchemaProperty] = Field(..., alias="anyOf") +# OpenAI tool definitions class OpenAIFunctionDefinition(PromptModel): """ Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/shared_params/function_definition.py#L13 From ee1b1197a34405d5d95a3b5f1ad7cd21844398a3 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 12:13:51 -0800 Subject: [PATCH 04/15] tests --- .../server/api/helpers/prompts/models.py | 57 ++--- tests/unit/server/api/helpers/test_models.py | 197 ++++++++++++++++++ 2 files changed, 230 insertions(+), 24 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 17707c0bfb..3ac8eaca80 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -1,8 +1,7 @@ -from abc import ABC from enum import Enum -from typing import Any, Generic, Literal, Optional, TypeVar, Union +from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import TypeAlias JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] @@ -102,46 +101,56 @@ class PromptToolsV1(PromptModel): "JSONSchemaPrimitiveProperty", "JSONSchemaContainerProperty", ] -JSONSchemaDataType = TypeVar( - "JSONSchemaDataType", - bound=Literal["number", "boolean", "null", "integer", "string", "array", 
"object"], -) -class BaseJSONSchemaProperty(ABC, Generic[JSONSchemaDataType], PromptModel): - type: JSONSchemaDataType +class JSONSchemaNumberProperty(PromptModel): + type: Literal["number"] description: str = UNDEFINED -class JSONSchemaNumberProperty(BaseJSONSchemaProperty[Literal["number"]]): - pass - - -class JSONSchemaBooleanProperty(BaseJSONSchemaProperty[Literal["boolean"]]): - pass +class JSONSchemaBooleanProperty(PromptModel): + type: Literal["boolean"] + description: str = UNDEFINED -class JSONSchemaNullProperty(BaseJSONSchemaProperty[Literal["null"]]): - pass +class JSONSchemaNullProperty(PromptModel): + type: Literal["null"] + description: str = UNDEFINED -class JSONSchemaIntegerProperty(BaseJSONSchemaProperty[Literal["integer"]]): - pass +class JSONSchemaIntegerProperty(PromptModel): + type: Literal["integer"] + description: str = UNDEFINED -class JSONSchemaStringProperty(BaseJSONSchemaProperty[Literal["string"]]): +class JSONSchemaStringProperty(PromptModel): + type: Literal["string"] + description: str = UNDEFINED enum: list[str] = UNDEFINED -class JSONSchemaArrayProperty(BaseJSONSchemaProperty[Literal["array"]]): - items: Union[JSONSchemaProperty, "JSONSchemaAnyOf"] = UNDEFINED +class JSONSchemaArrayProperty(PromptModel): + type: Literal["array"] + description: str = UNDEFINED + items: Union[JSONSchemaProperty, "JSONSchemaAnyOf"] -class JSONSchemaObjectProperty(BaseJSONSchemaProperty[Literal["object"]]): - properties: dict[str, Union[JSONSchemaProperty, "JSONSchemaAnyOf"]] = UNDEFINED +class JSONSchemaObjectProperty(PromptModel): + type: Literal["object"] + description: str = UNDEFINED + properties: dict[str, Union[JSONSchemaProperty, "JSONSchemaAnyOf"]] required: list[str] = UNDEFINED additional_properties: bool = Field(UNDEFINED, alias="additionalProperties") + @model_validator(mode="after") + def ensure_required_fields_are_included_in_properties(self) -> "JSONSchemaObjectProperty": + if self.required is UNDEFINED: + return self + invalid_fields = [field for field in self.required if field not in self.properties] + if invalid_fields: + raise ValueError(f"Required fields {invalid_fields} are not defined in properties") + return self + class JSONSchemaAnyOf(PromptModel): description: str = UNDEFINED diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py index e29ee2b028..17b2d10efa 100644 --- a/tests/unit/server/api/helpers/test_models.py +++ b/tests/unit/server/api/helpers/test_models.py @@ -1,6 +1,7 @@ from typing import Any import pytest +from pydantic import ValidationError from phoenix.server.api.helpers.prompts.models import OpenAIToolDefinition @@ -495,7 +496,203 @@ }, id="pick-tshirt-size-function", ), + pytest.param( + { + "type": "function", + "function": { + "name": "test_primitives", + "description": "Test all primitive types", + "parameters": { + "type": "object", + "properties": { + "string_field": {"type": "string", "description": "A string field"}, + "number_field": {"type": "number", "description": "A number field"}, + "integer_field": {"type": "integer", "description": "An integer field"}, + "boolean_field": {"type": "boolean", "description": "A boolean field"}, + "null_field": {"type": "null", "description": "A null field"}, + }, + "required": [ + "string_field", + "number_field", + "integer_field", + "boolean_field", + "null_field", + ], + "additionalProperties": False, + }, + }, + }, + id="primitive-types-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": 
"update_user_profile", + "description": "Updates a user's profile information", + "parameters": { + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "The ID of the user to update", + }, + "nickname": { + "description": "Optional nickname that can be null or a string", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + }, + "required": ["user_id"], + "additionalProperties": False, + }, + }, + }, + id="optional-anyof-parameter", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "process_user_preferences", + "description": "Process optional user preferences", + "parameters": { + "type": "object", + "properties": { + "email_notifications": { + "description": "Email notification preferences that can be null or boolean", # noqa: E501 + "anyOf": [{"type": "boolean"}, {"type": "null"}], + }, + "display_name": { + "description": "Display name that can be null or string", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "age": { + "description": "Age that can be null or integer", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + }, + "additionalProperties": False, + }, + }, + }, + id="array-of-optional-parameters", + ), ], ) def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[str, Any]) -> None: OpenAIToolDefinition.model_validate(tool_definition) + + +@pytest.mark.parametrize( + "tool_definition", + [ + pytest.param( + { + "type": "function", + "function": { + "name": "pick_tshirt_size", + "description": "Call this if the user specifies which size t-shirt they want", + "parameters": { + "type": "object", + "properties": { + "size": { + "type": "invalid_type", + "enum": ["s", "m", "l"], + "description": "The size of the t-shirt that the user would like to order", # noqa: E501 + } + }, + "required": ["size"], + "additionalProperties": False, + }, + }, + }, + id="invalid-data-type", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Sets the temperature for the thermostat", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "number", + "enum": ["70", "72", "74"], # only string properties can have enums + "description": "The temperature to set in Fahrenheit", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="number-property-with-invalid-enum", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "extra": "extra", # extra properties are not allowed + }, + }, + }, + id="extra-properties", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "update_user", + "description": "Updates user information", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": [ + "name", + "email", # email is not in properties + ], + "additionalProperties": False, + }, + }, + }, + id="required-field-not-in-properties", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_preferences", + "parameters": { + "type": "object", + "properties": { + "priority": { + "type": "string", + "enum": [ + 0, # integer enum values not allowed + "low", + "medium", + "high", + ], + "description": "The priority level to set", + } + }, + "required": ["priority"], + "additionalProperties": False, + }, + }, + }, + id="string-property-with-priority-enum", + ), + ], +) +def 
test_openai_tool_definition_fails_invalid_tool_schemas(tool_definition: dict[str, Any]) -> None: + with pytest.raises(ValidationError): + OpenAIToolDefinition.model_validate(tool_definition) From c8981ab87c39731a9e33be9d35d9ba6f1c3ce826 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 12:18:25 -0800 Subject: [PATCH 05/15] array of anyof --- tests/unit/server/api/helpers/test_models.py | 33 +++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py index 17b2d10efa..f6245a31c2 100644 --- a/tests/unit/server/api/helpers/test_models.py +++ b/tests/unit/server/api/helpers/test_models.py @@ -553,29 +553,32 @@ { "type": "function", "function": { - "name": "process_user_preferences", - "description": "Process optional user preferences", + "name": "categorize_colors", + "description": "Categorize colors into warm, cool, or neutral tones, with null for uncertain cases", # noqa: E501 "parameters": { "type": "object", "properties": { - "email_notifications": { - "description": "Email notification preferences that can be null or boolean", # noqa: E501 - "anyOf": [{"type": "boolean"}, {"type": "null"}], - }, - "display_name": { - "description": "Display name that can be null or string", - "anyOf": [{"type": "string"}, {"type": "null"}], - }, - "age": { - "description": "Age that can be null or integer", - "anyOf": [{"type": "integer"}, {"type": "null"}], - }, + "colors": { + "type": "array", + "description": "List of color categories, with null for uncertain colors", # noqa: E501 + "items": { + "anyOf": [ + { + "type": "string", + "enum": ["warm", "cool", "neutral"], + "description": "Color category", + }, + {"type": "null"}, + ] + }, + } }, + "required": ["colors"], "additionalProperties": False, }, }, }, - id="array-of-optional-parameters", + id="array-of-optional-enums", ), ], ) From 5e8aef6f5c84f239fb8b7d091d47e3becdb70a28 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 12:25:23 -0800 Subject: [PATCH 06/15] duplicate enum value --- .../server/api/helpers/prompts/models.py | 10 ++++++- tests/unit/server/api/helpers/test_models.py | 26 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 3ac8eaca80..d38cadc701 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from typing_extensions import TypeAlias JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] @@ -128,6 +128,14 @@ class JSONSchemaStringProperty(PromptModel): description: str = UNDEFINED enum: list[str] = UNDEFINED + @field_validator("enum") + def ensure_unique_enum_values(cls, enum_values: list[str]) -> list[str]: + if enum_values is UNDEFINED: + return enum_values + if len(enum_values) != len(set(enum_values)): + raise ValueError("Enum values must be unique") + return enum_values + class JSONSchemaArrayProperty(PromptModel): type: Literal["array"] diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py index f6245a31c2..27b613f50c 100644 --- 
a/tests/unit/server/api/helpers/test_models.py +++ b/tests/unit/server/api/helpers/test_models.py @@ -694,6 +694,32 @@ def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[ }, id="string-property-with-priority-enum", ), + pytest.param( + { + "type": "function", + "function": { + "name": "select_color", + "description": "Select a color from the available options", + "parameters": { + "type": "object", + "properties": { + "color": { + "type": "string", + "enum": [ + "red", + "blue", + "red", # duplicate enum value + ], + "description": "The color to select", + } + }, + "required": ["color"], + "additionalProperties": False, + }, + }, + }, + id="duplicate-enum-values", + ), ], ) def test_openai_tool_definition_fails_invalid_tool_schemas(tool_definition: dict[str, Any]) -> None: From bcd2c842f29afd3ccf4fa1f849ed5402b869e2e7 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 12:56:50 -0800 Subject: [PATCH 07/15] integer min max --- .../server/api/helpers/prompts/models.py | 36 +++++- tests/unit/server/api/helpers/test_models.py | 114 ++++++++++++++++++ 2 files changed, 144 insertions(+), 6 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index d38cadc701..3b9e3ebf4a 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -87,10 +87,10 @@ class PromptToolsV1(PromptModel): # JSON schema JSONSchemaPrimitiveProperty: TypeAlias = Union[ + "JSONSchemaIntegerProperty", "JSONSchemaNumberProperty", "JSONSchemaBooleanProperty", "JSONSchemaNullProperty", - "JSONSchemaIntegerProperty", "JSONSchemaStringProperty", ] JSONSchemaContainerProperty: TypeAlias = Union[ @@ -103,9 +103,38 @@ class PromptToolsV1(PromptModel): ] +class JSONSchemaIntegerProperty(PromptModel): + type: Literal["integer"] + description: str = UNDEFINED + minimum: int = UNDEFINED + maximum: int = UNDEFINED + + @model_validator(mode="after") + def ensure_minimum_lte_maximum(self) -> "JSONSchemaIntegerProperty": + if ( + self.minimum is not UNDEFINED + and self.maximum is not UNDEFINED + and self.minimum > self.maximum + ): + raise ValueError("minimum must be less than or equal to maximum") + return self + + class JSONSchemaNumberProperty(PromptModel): type: Literal["number"] description: str = UNDEFINED + minimum: float = UNDEFINED + maximum: float = UNDEFINED + + @model_validator(mode="after") + def ensure_minimum_lte_maximum(self) -> "JSONSchemaNumberProperty": + if ( + self.minimum is not UNDEFINED + and self.maximum is not UNDEFINED + and self.minimum > self.maximum + ): + raise ValueError("minimum must be less than or equal to maximum") + return self class JSONSchemaBooleanProperty(PromptModel): @@ -118,11 +147,6 @@ class JSONSchemaNullProperty(PromptModel): description: str = UNDEFINED -class JSONSchemaIntegerProperty(PromptModel): - type: Literal["integer"] - description: str = UNDEFINED - - class JSONSchemaStringProperty(PromptModel): type: Literal["string"] description: str = UNDEFINED diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py index 27b613f50c..627d290a5c 100644 --- a/tests/unit/server/api/helpers/test_models.py +++ b/tests/unit/server/api/helpers/test_models.py @@ -580,6 +580,52 @@ }, id="array-of-optional-enums", ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature within valid range", + "parameters": { + "type": 
"object", + "properties": { + "temp": { + "type": "integer", + "minimum": 0, + "maximum": 100, + "description": "Temperature in Fahrenheit (0-100)", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="integer-min-max-constraints", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature within valid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "number", + "minimum": 0.5, # float min + "maximum": 100, # integer max + "description": "Temperature in Fahrenheit (0-100)", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="number-min-max-constraints", + ), ], ) def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[str, Any]) -> None: @@ -720,6 +766,74 @@ def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[ }, id="duplicate-enum-values", ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature with invalid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "integer", + "minimum": 100, + "maximum": 0, # min > max + "description": "Temperature in Celsius", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="integer-min-max-range", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_count", + "description": "Set an integer count with float bounds", + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "minimum": 1.4, # float not allowed for integer property + "description": "Count value", + } + }, + "required": ["count"], + "additionalProperties": False, + }, + }, + }, + id="integer-float-bounds", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature with invalid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "number", + "minimum": 100, + "maximum": 0, # min > max + "description": "Temperature in Celsius", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="number-min-max-range", + ), ], ) def test_openai_tool_definition_fails_invalid_tool_schemas(tool_definition: dict[str, Any]) -> None: From 35d3589a910ec0cac623449bab11b43bf0da3764 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 16:02:01 -0800 Subject: [PATCH 08/15] wire up openai tool schema check for create prompt mutation --- .../server/api/helpers/prompts/models.py | 30 ++++++- .../server/api/mutations/prompt_mutations.py | 46 ++++++---- .../api/mutations/test_prompt_mutations.py | 90 ++++++++++++++++++- 3 files changed, 142 insertions(+), 24 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 3b9e3ebf4a..89fa345ad6 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator from typing_extensions import TypeAlias JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] @@ -82,7 +82,33 @@ class 
PromptToolDefinition(PromptModel): class PromptToolsV1(PromptModel): version: Literal["tools-v1"] = "tools-v1" - tool_definitions: list[PromptToolDefinition] = Field(..., min_length=1) + tool_definitions: list[PromptToolDefinition] + + +class PromptVersion(PromptModel): + user_id: Optional[int] + description: Optional[str] + template_type: PromptTemplateType + template_format: PromptTemplateFormat + template: PromptTemplate + invocation_parameters: Optional[dict[str, Any]] + tools: PromptToolsV1 + output_schema: Optional[dict[str, Any]] + model_name: str + model_provider: str + + @model_validator(mode="after") + def validate_tool_definitions_for_known_model_providers(self) -> "PromptVersion": + tool_definitions = [tool_def.definition for tool_def in self.tools.tool_definitions] + if self.model_provider.lower() == "openai": + for tool_definition_index, tool_definition in enumerate(tool_definitions): + try: + OpenAIToolDefinition.model_validate(tool_definition) + except ValidationError as e: + raise ValueError( + f"Invalid OpenAI tool definition at index {tool_definition_index}: {e}" + ) + return self # JSON schema diff --git a/src/phoenix/server/api/mutations/prompt_mutations.py b/src/phoenix/server/api/mutations/prompt_mutations.py index 4c373f376e..d3040f3a17 100644 --- a/src/phoenix/server/api/mutations/prompt_mutations.py +++ b/src/phoenix/server/api/mutations/prompt_mutations.py @@ -16,6 +16,7 @@ PromptChatTemplateV1, PromptJSONSchema, PromptToolsV1, + PromptVersion, ) from phoenix.server.api.input_types.PromptVersionInput import ChatPromptVersionInput from phoenix.server.api.queries import Query @@ -48,16 +49,18 @@ class PromptMutationMixin: async def create_chat_prompt( self, info: Info[Context, None], input: CreateChatPromptInput ) -> Prompt: + user_id: Optional[int] = None + assert isinstance(request := info.context.request, Request) + if "user" in request.scope: + assert isinstance(user := request.user, PhoenixUser) + user_id = int(user.identity) + try: tool_definitions = [] for tool in input.prompt_version.tools: pydantic_tool = tool.to_pydantic() - tool_definitions.append(pydantic_tool.dict()) - tools = ( - PromptToolsV1(tool_definitions=tool_definitions).dict() - if tool_definitions - else None - ) + tool_definitions.append(pydantic_tool) + tools = PromptToolsV1(tool_definitions=tool_definitions) output_schema = ( PromptJSONSchema.model_validate( strawberry.asdict(input.prompt_version.output_schema) @@ -68,26 +71,33 @@ async def create_chat_prompt( template = PromptChatTemplateV1.model_validate( strawberry.asdict(input.prompt_version.template) ).dict() - except ValidationError as error: - raise BadRequest(str(error)) - - user_id: Optional[int] = None - assert isinstance(request := info.context.request, Request) - if "user" in request.scope: - assert isinstance(user := request.user, PhoenixUser) - user_id = int(user.identity) - async with info.context.db() as session: - prompt_version = models.PromptVersion( - description=input.prompt_version.description, + pydantic_prompt_version = PromptVersion( user_id=user_id, + description=input.prompt_version.description, template_type=input.prompt_version.template_type.value, template_format=input.prompt_version.template_format.value, template=template, invocation_parameters=input.prompt_version.invocation_parameters, tools=tools, output_schema=output_schema, - model_provider=input.prompt_version.model_provider, model_name=input.prompt_version.model_name, + model_provider=input.prompt_version.model_provider, + ) + except ValidationError 
as error: + raise BadRequest(str(error)) + + async with info.context.db() as session: + prompt_version = models.PromptVersion( + description=pydantic_prompt_version.description, + user_id=pydantic_prompt_version.user_id, + template_type=pydantic_prompt_version.template_type, + template_format=pydantic_prompt_version.template_format, + template=pydantic_prompt_version.template.dict(), + invocation_parameters=pydantic_prompt_version.invocation_parameters, + tools=pydantic_prompt_version.tools.dict(), + output_schema=pydantic_prompt_version.output_schema, + model_provider=pydantic_prompt_version.model_provider, + model_name=pydantic_prompt_version.model_name, ) prompt = models.Prompt( name=input.name, diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py index eb30d7274a..773b2682ac 100644 --- a/tests/unit/server/api/mutations/test_prompt_mutations.py +++ b/tests/unit/server/api/mutations/test_prompt_mutations.py @@ -124,8 +124,8 @@ class TestPromptMutations: "templateFormat": "MUSTACHE", "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, - "modelProvider": "openai", - "modelName": "o1-mini", + "modelProvider": "anthropic", + "modelName": "claude-3-5-sonnet", "tools": [{"definition": {"foo": "bar"}}], }, } @@ -134,6 +134,53 @@ class TestPromptMutations: None, id="with-tools", ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + None, + id="with-valid-openai-tools", + ), pytest.param( { "input": { @@ -180,8 +227,10 @@ async def test_create_chat_prompt_succeeds_with_valid_input( assert prompt_version.pop("user") is None assert prompt_version.pop("templateType") == "CHAT" assert prompt_version.pop("templateFormat") == "MUSTACHE" - assert prompt_version.pop("modelProvider") == "openai" - assert prompt_version.pop("modelName") == "o1-mini" + expected_model_provider = variables["input"]["promptVersion"]["modelProvider"] + expected_model_name = variables["input"]["promptVersion"]["modelName"] + assert prompt_version.pop("modelProvider") == expected_model_provider + assert prompt_version.pop("modelName") == expected_model_name assert prompt_version.pop("invocationParameters") == {"temperature": 0.4} assert prompt_version.pop("tools") == expected_tools assert prompt_version.pop("outputSchema") == expected_output_schema @@ -272,6 +321,39 @@ async def test_create_chat_prompt_fails_on_name_conflict( "Input should be a valid dictionary", id="invalid-output-schema", ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": 
{"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "invalid_type", # invalid schema type + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + "function.parameters.type", + id="with-invalid-openai-tools", + ), ], ) async def test_create_chat_prompt_fails_with_invalid_input( From e090f5edfaf8edde15c902da7487a5b78ebc3cec Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 16:16:58 -0800 Subject: [PATCH 09/15] wire up create prompt version --- .../server/api/mutations/prompt_mutations.py | 51 ++++++----- .../api/mutations/test_prompt_mutations.py | 88 ++++++++++++++++++- 2 files changed, 114 insertions(+), 25 deletions(-) diff --git a/src/phoenix/server/api/mutations/prompt_mutations.py b/src/phoenix/server/api/mutations/prompt_mutations.py index d3040f3a17..4f2037ae12 100644 --- a/src/phoenix/server/api/mutations/prompt_mutations.py +++ b/src/phoenix/server/api/mutations/prompt_mutations.py @@ -117,16 +117,18 @@ async def create_chat_prompt_version( info: Info[Context, None], input: CreateChatPromptVersionInput, ) -> Prompt: + user_id: Optional[int] = None + assert isinstance(request := info.context.request, Request) + if "user" in request.scope: + assert isinstance(user := request.user, PhoenixUser) + user_id = int(user.identity) + try: tool_definitions = [] for tool in input.prompt_version.tools: pydantic_tool = tool.to_pydantic() - tool_definitions.append(pydantic_tool.dict()) - tools = ( - PromptToolsV1(tool_definitions=tool_definitions).dict() - if tool_definitions - else None - ) + tool_definitions.append(pydantic_tool) + tools = PromptToolsV1(tool_definitions=tool_definitions) output_schema = ( PromptJSONSchema.model_validate( strawberry.asdict(input.prompt_version.output_schema) @@ -137,14 +139,21 @@ async def create_chat_prompt_version( template = PromptChatTemplateV1.model_validate( strawberry.asdict(input.prompt_version.template) ).dict() + pydantic_prompt_version = PromptVersion( + user_id=user_id, + description=input.prompt_version.description, + template_type=input.prompt_version.template_type.value, + template_format=input.prompt_version.template_format.value, + template=template, + invocation_parameters=input.prompt_version.invocation_parameters, + tools=tools, + output_schema=output_schema, + model_name=input.prompt_version.model_name, + model_provider=input.prompt_version.model_provider, + ) except ValidationError as error: raise BadRequest(str(error)) - user_id: Optional[int] = None - assert isinstance(request := info.context.request, Request) - if "user" in request.scope: - assert isinstance(user := request.user, PhoenixUser) - user_id = int(user.identity) prompt_id = from_global_id_with_expected_type( global_id=input.prompt_id, expected_type_name=Prompt.__name__ ) @@ -155,16 +164,16 @@ async def create_chat_prompt_version( prompt_version = models.PromptVersion( prompt_id=prompt_id, - user_id=user_id, - description=input.prompt_version.description, - template_type=input.prompt_version.template_type.value, - template_format=input.prompt_version.template_format.value, - template=template, - invocation_parameters=input.prompt_version.invocation_parameters, - tools=tools, - output_schema=output_schema, - model_provider=input.prompt_version.model_provider, - 
model_name=input.prompt_version.model_name, + user_id=pydantic_prompt_version.user_id, + description=pydantic_prompt_version.description, + template_type=pydantic_prompt_version.template_type, + template_format=pydantic_prompt_version.template_format, + template=pydantic_prompt_version.template.dict(), + invocation_parameters=pydantic_prompt_version.invocation_parameters, + tools=pydantic_prompt_version.tools.dict(), + output_schema=pydantic_prompt_version.output_schema, + model_provider=pydantic_prompt_version.model_provider, + model_name=pydantic_prompt_version.model_name, ) session.add(prompt_version) diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py index 773b2682ac..6d5d3426ad 100644 --- a/tests/unit/server/api/mutations/test_prompt_mutations.py +++ b/tests/unit/server/api/mutations/test_prompt_mutations.py @@ -396,8 +396,8 @@ async def test_create_chat_prompt_fails_with_invalid_input( "templateFormat": "MUSTACHE", "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, - "modelProvider": "openai", - "modelName": "o1-mini", + "modelProvider": "anthropic", + "modelName": "claude-3-5-sonnet", "tools": [{"definition": {"foo": "bar"}}], }, } @@ -406,6 +406,52 @@ async def test_create_chat_prompt_fails_with_invalid_input( None, id="with-tools", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + None, + id="with-valid-openai-tools", + ), pytest.param( { "input": { @@ -473,8 +519,10 @@ async def test_create_chat_prompt_version_succeeds_with_valid_input( assert latest_prompt_version.pop("user") is None assert latest_prompt_version.pop("templateType") == "CHAT" assert latest_prompt_version.pop("templateFormat") == "MUSTACHE" - assert latest_prompt_version.pop("modelProvider") == "openai" - assert latest_prompt_version.pop("modelName") == "o1-mini" + expected_model_provider = variables["input"]["promptVersion"]["modelProvider"] + expected_model_name = variables["input"]["promptVersion"]["modelName"] + assert latest_prompt_version.pop("modelProvider") == expected_model_provider + assert latest_prompt_version.pop("modelName") == expected_model_name assert latest_prompt_version.pop("invocationParameters") == {"temperature": 0.4} assert latest_prompt_version.pop("tools") == expected_tools assert latest_prompt_version.pop("outputSchema") == expected_output_schema @@ -558,6 +606,38 @@ async def test_create_chat_prompt_version_fails_with_nonexistent_prompt_id( "Input should be a valid dictionary", id="invalid-output-schema", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": 
"MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "invalid_type", # invalid schema type + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + "function.parameters.type", + id="with-invalid-openai-tools", + ), ], ) async def test_create_chat_prompt_version_fails_with_invalid_input( From 806b15163dca3de01ad3edeb8240f0409145530d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 Jan 2025 16:28:46 -0800 Subject: [PATCH 10/15] update schema --- schemas/openapi.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schemas/openapi.json b/schemas/openapi.json index e485d01171..58d620441f 100644 --- a/schemas/openapi.json +++ b/schemas/openapi.json @@ -1918,6 +1918,7 @@ "title": "Definition" } }, + "additionalProperties": false, "type": "object", "required": [ "definition" @@ -1993,7 +1994,6 @@ "$ref": "#/components/schemas/PromptToolDefinition" }, "type": "array", - "minItems": 1, "title": "Tool Definitions" } }, From aa188dfd235a80ae9c714f0f17404dd694e44def Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 8 Jan 2025 14:11:13 -0800 Subject: [PATCH 11/15] anthropic --- .../server/api/helpers/prompts/models.py | 20 +++ tests/unit/server/api/helpers/test_models.py | 123 +++++++++++++++++- 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 89fa345ad6..4098917e1b 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -234,3 +234,23 @@ class OpenAIToolDefinition(PromptModel): function: OpenAIFunctionDefinition type: Literal["function"] + + +class AnthropicCacheControlEphemeralParam(PromptModel): + """ + Based on https://github.com/anthropics/anthropic-sdk-python/blob/93cbbbde964e244f02bf1bd2b579c5fabce4e267/src/anthropic/types/cache_control_ephemeral_param.py#L10 + """ + + type: Literal["ephemeral"] + + +# Anthropic tool definitions +class AnthropicToolDefinition(PromptModel): + """ + Based on https://github.com/anthropics/anthropic-sdk-python/blob/93cbbbde964e244f02bf1bd2b579c5fabce4e267/src/anthropic/types/tool_param.py#L22 + """ + + input_schema: JSONSchemaObjectProperty + name: str + cache_control: Optional[AnthropicCacheControlEphemeralParam] = UNDEFINED + description: str = UNDEFINED diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py index 627d290a5c..d253e84f98 100644 --- a/tests/unit/server/api/helpers/test_models.py +++ b/tests/unit/server/api/helpers/test_models.py @@ -3,7 +3,7 @@ import pytest from pydantic import ValidationError -from phoenix.server.api.helpers.prompts.models import OpenAIToolDefinition +from phoenix.server.api.helpers.prompts.models import AnthropicToolDefinition, OpenAIToolDefinition @pytest.mark.parametrize( @@ -839,3 +839,124 @@ def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[ def test_openai_tool_definition_fails_invalid_tool_schemas(tool_definition: dict[str, Any]) -> None: with pytest.raises(ValidationError): OpenAIToolDefinition.model_validate(tool_definition) + + +@pytest.mark.parametrize( + "tool_definition", + [ + pytest.param( + { + "name": "get_weather", + "description": 
"Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + }, + id="get-weather-function", + ), + pytest.param( + { + "name": "get_time", + "description": "Get the current time in a given time zone", + "input_schema": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The IANA time zone name, e.g. America/Los_Angeles", + } + }, + "required": ["timezone"], + }, + }, + id="get-time-function", + ), + pytest.param( + { + "name": "get_location", + "description": "Get the current user location based on their IP address. This tool has no parameters or arguments.", # noqa: E501 + "input_schema": {"type": "object", "properties": {}}, + }, + id="get-location-function", + ), + pytest.param( + { + "name": "record_summary", + "description": "Record summary of an image using well-structured JSON.", + "input_schema": { + "type": "object", + "properties": { + "key_colors": { + "type": "array", + "items": { + "type": "object", + "properties": { + "r": {"type": "number", "description": "red value [0.0, 1.0]"}, + "g": { + "type": "number", + "description": "green value [0.0, 1.0]", + }, + "b": {"type": "number", "description": "blue value [0.0, 1.0]"}, + "name": { + "type": "string", + "description": 'Human-readable color name in snake_case, e.g. "olive_green" or "turquoise"', # noqa: E501 + }, + }, + "required": ["r", "g", "b", "name"], + }, + "description": "Key colors in the image. Limit to less then four.", + }, + "description": { + "type": "string", + "description": "Image description. One to two sentences max.", + }, + "estimated_year": { + "type": "integer", + "description": "Estimated year that the images was taken, if is it a photo. Only set this if the image appears to be non-fictional. 
Rough estimates are okay!", # noqa: E501 + }, + }, + "required": ["key_colors", "description"], + }, + }, + id="record-image-summary", + ), + pytest.param( + { + "name": "get_weather", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + "cache_control": {"type": "ephemeral"}, + }, + id="get-weather-function-cache-control-ephemeral", + ), + pytest.param( + { + "name": "get_weather", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + "cache_control": None, + }, + id="get-weather-function-cache-control-none", + ), + ], +) +def test_anthropic_tool_definition_passes_valid_tool_schemas( + tool_definition: dict[str, Any], +) -> None: + AnthropicToolDefinition.model_validate(tool_definition) From f685d69a87ced84b8b1c5684206ca8988dc671c8 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 8 Jan 2025 14:46:13 -0800 Subject: [PATCH 12/15] wire up anthropic --- .../server/api/helpers/prompts/models.py | 8 + .../api/mutations/test_prompt_mutations.py | 226 +++++++++++++++++- 2 files changed, 230 insertions(+), 4 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 4098917e1b..6de4243d3a 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -108,6 +108,14 @@ def validate_tool_definitions_for_known_model_providers(self) -> "PromptVersion" raise ValueError( f"Invalid OpenAI tool definition at index {tool_definition_index}: {e}" ) + if self.model_provider.lower() == "anthropic": + for tool_definition_index, tool_definition in enumerate(tool_definitions): + try: + AnthropicToolDefinition.model_validate(tool_definition) + except ValidationError as e: + raise ValueError( + f"Invalid Anthropic tool definition at index {tool_definition_index}: {e}" + ) return self diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py index 6d5d3426ad..1ae94dd847 100644 --- a/tests/unit/server/api/mutations/test_prompt_mutations.py +++ b/tests/unit/server/api/mutations/test_prompt_mutations.py @@ -124,8 +124,8 @@ class TestPromptMutations: "templateFormat": "MUSTACHE", "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, - "modelProvider": "anthropic", - "modelName": "claude-3-5-sonnet", + "modelProvider": "unknown", + "modelName": "unknown", "tools": [{"definition": {"foo": "bar"}}], }, } @@ -181,6 +181,71 @@ class TestPromptMutations: None, id="with-valid-openai-tools", ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + } + } + ], + }, + } + }, + [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + } + } + ], + None, + id="with-valid-anthropic-tools", + ), pytest.param( { "input": { @@ -354,6 +419,51 @@ async def test_create_chat_prompt_fails_on_name_conflict( "function.parameters.type", id="with-invalid-openai-tools", ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + "cache_control": { + "type": "invalid_type" + }, # invalid cache control type + } + } + ], + }, + } + }, + "cache_control.type", + id="with-invalid-anthropic-tools", + ), ], ) async def test_create_chat_prompt_fails_with_invalid_input( @@ -396,8 +506,8 @@ async def test_create_chat_prompt_fails_with_invalid_input( "templateFormat": "MUSTACHE", "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, - "modelProvider": "anthropic", - "modelName": "claude-3-5-sonnet", + "modelProvider": "unknown", + "modelName": "unknown", "tools": [{"definition": {"foo": "bar"}}], }, } @@ -472,6 +582,70 @@ async def test_create_chat_prompt_fails_with_invalid_input( {"definition": {"foo": "bar"}}, id="with-output-schema", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + } + } + ], + }, + } + }, + [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + } + } + ], + None, + id="with-valid-anthropic-tools", + ), ], ) async def test_create_chat_prompt_version_succeeds_with_valid_input( @@ -638,6 +812,50 @@ async def test_create_chat_prompt_version_fails_with_nonexistent_prompt_id( "function.parameters.type", id="with-invalid-openai-tools", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + "cache_control": { + "type": "invalid_type" + }, # invalid cache control type + } + } + ], + }, + } + }, + "cache_control.type", + id="with-invalid-anthropic-tools", + ), ], ) async def test_create_chat_prompt_version_fails_with_invalid_input( From 5a9f8990e88798454f0ccd2156795def6d8d4708 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 8 Jan 2025 14:56:02 -0800 Subject: [PATCH 13/15] simplify tests --- .../api/mutations/test_prompt_mutations.py | 104 ++---------------- 1 file changed, 7 insertions(+), 97 deletions(-) diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py index 1ae94dd847..c0b801dc38 100644 --- a/tests/unit/server/api/mutations/test_prompt_mutations.py +++ b/tests/unit/server/api/mutations/test_prompt_mutations.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pytest from strawberry.relay.types import GlobalID @@ -91,7 +91,7 @@ class TestPromptMutations: """ @pytest.mark.parametrize( - "variables,expected_tools,expected_output_schema", + "variables", [ pytest.param( { @@ -109,8 +109,6 @@ class TestPromptMutations: }, } }, - [], - None, id="basic-input", ), pytest.param( @@ -130,8 +128,6 @@ class TestPromptMutations: }, } }, - [{"definition": {"foo": "bar"}}], - None, id="with-tools", ), pytest.param( @@ -164,21 +160,6 @@ class TestPromptMutations: }, } }, - [ - { - "definition": { - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, - }, - } - } - ], - 
None, id="with-valid-openai-tools", ), pytest.param( @@ -220,30 +201,6 @@ class TestPromptMutations: }, } }, - [ - { - "definition": { - "name": "get_weather", - "description": "Get the current weather in a given location", - "input_schema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 - }, - }, - "required": ["location"], - }, - } - } - ], - None, id="with-valid-anthropic-tools", ), pytest.param( @@ -263,8 +220,6 @@ class TestPromptMutations: }, } }, - [], - {"definition": {"foo": "bar"}}, id="with-output-schema", ), ], @@ -274,8 +229,6 @@ async def test_create_chat_prompt_succeeds_with_valid_input( db: DbSessionFactory, gql_client: AsyncGraphQLClient, variables: dict[str, Any], - expected_tools: list[dict[str, Any]], - expected_output_schema: Optional[dict[str, Any]], ) -> None: result = await gql_client.execute(self.CREATE_CHAT_PROMPT_MUTATION, variables) assert not result.errors @@ -297,7 +250,9 @@ async def test_create_chat_prompt_succeeds_with_valid_input( assert prompt_version.pop("modelProvider") == expected_model_provider assert prompt_version.pop("modelName") == expected_model_name assert prompt_version.pop("invocationParameters") == {"temperature": 0.4} + expected_tools = variables["input"]["promptVersion"].get("tools", []) assert prompt_version.pop("tools") == expected_tools + expected_output_schema = variables["input"]["promptVersion"].get("outputSchema") assert prompt_version.pop("outputSchema") == expected_output_schema assert isinstance(prompt_version.pop("id"), str) @@ -475,7 +430,7 @@ async def test_create_chat_prompt_fails_with_invalid_input( assert result.data is None @pytest.mark.parametrize( - "variables,expected_tools,expected_output_schema", + "variables", [ pytest.param( { @@ -492,8 +447,6 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [], - None, id="basic-input", ), pytest.param( @@ -512,8 +465,6 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [{"definition": {"foo": "bar"}}], - None, id="with-tools", ), pytest.param( @@ -545,21 +496,6 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [ - { - "definition": { - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, - }, - } - } - ], - None, id="with-valid-openai-tools", ), pytest.param( @@ -578,8 +514,6 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [], - {"definition": {"foo": "bar"}}, id="with-output-schema", ), pytest.param( @@ -620,30 +554,6 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [ - { - "definition": { - "name": "get_weather", - "description": "Get the current weather in a given location", - "input_schema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 - }, - }, - "required": ["location"], - }, - } - } - ], - None, id="with-valid-anthropic-tools", ), ], @@ -653,8 +563,6 @@ async def test_create_chat_prompt_version_succeeds_with_valid_input( db: DbSessionFactory, gql_client: AsyncGraphQLClient, variables: dict[str, Any], - expected_tools: list[dict[str, Any]], - expected_output_schema: Optional[dict[str, Any]], ) -> None: # Create initial prompt create_prompt_result = await gql_client.execute( @@ -698,7 +606,9 @@ async def test_create_chat_prompt_version_succeeds_with_valid_input( assert latest_prompt_version.pop("modelProvider") == expected_model_provider assert latest_prompt_version.pop("modelName") == expected_model_name assert latest_prompt_version.pop("invocationParameters") == {"temperature": 0.4} + expected_tools = variables["input"]["promptVersion"].get("tools", []) assert latest_prompt_version.pop("tools") == expected_tools + expected_output_schema = variables["input"]["promptVersion"].get("outputSchema") assert latest_prompt_version.pop("outputSchema") == expected_output_schema assert isinstance(latest_prompt_version.pop("id"), str) From 2a9c2483c03f30241ed3eaa9b0189a55712f2921 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 8 Jan 2025 15:04:10 -0800 Subject: [PATCH 14/15] helper function --- .../server/api/helpers/prompts/models.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index 6de4243d3a..d3fb1c55b1 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -100,25 +100,28 @@ class PromptVersion(PromptModel): @model_validator(mode="after") def validate_tool_definitions_for_known_model_providers(self) -> "PromptVersion": tool_definitions = [tool_def.definition for tool_def in self.tools.tool_definitions] - if self.model_provider.lower() == "openai": + tool_definition_model = _get_tool_definition_model(self.model_provider) + if tool_definition_model: for tool_definition_index, tool_definition in enumerate(tool_definitions): try: - OpenAIToolDefinition.model_validate(tool_definition) - except ValidationError as e: + tool_definition_model.model_validate(tool_definition) + except ValidationError as error: raise ValueError( - f"Invalid OpenAI tool definition at index {tool_definition_index}: {e}" - ) - if self.model_provider.lower() == "anthropic": - for tool_definition_index, tool_definition in enumerate(tool_definitions): - try: - AnthropicToolDefinition.model_validate(tool_definition) - except ValidationError as e: - raise ValueError( - f"Invalid Anthropic tool definition at index {tool_definition_index}: {e}" + f"Invalid tool definition at index {tool_definition_index}: {error}" ) return self +def _get_tool_definition_model( + model_provider: str, +) -> Optional[Union["OpenAIToolDefinition", "AnthropicToolDefinition"]]: + if model_provider.lower() == "openai": + return OpenAIToolDefinition + if model_provider.lower() == "anthropic": + return AnthropicToolDefinition + return None + + # JSON schema JSONSchemaPrimitiveProperty: TypeAlias = Union[ "JSONSchemaIntegerProperty", From 52a2a0ab531dd39da580c7099fb5c075545d9338 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 8 Jan 2025 15:11:23 -0800 Subject: [PATCH 15/15] fix types --- 
src/phoenix/server/api/helpers/prompts/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index d3fb1c55b1..d287b7f077 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -114,7 +114,7 @@ def validate_tool_definitions_for_known_model_providers(self) -> "PromptVersion" def _get_tool_definition_model( model_provider: str, -) -> Optional[Union["OpenAIToolDefinition", "AnthropicToolDefinition"]]: +) -> Optional[Union[type["OpenAIToolDefinition"], type["AnthropicToolDefinition"]]]: if model_provider.lower() == "openai": return OpenAIToolDefinition if model_provider.lower() == "anthropic":
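
For reference, a minimal usage sketch of the validation flow introduced in this
series (assuming a Python environment with these patches applied; the tool
payloads below are illustrative and not taken from the patches):

    from pydantic import ValidationError

    from phoenix.server.api.helpers.prompts.models import (
        AnthropicToolDefinition,
        OpenAIToolDefinition,
        _get_tool_definition_model,
    )

    # Provider dispatch is case-insensitive and returns None for unknown
    # providers, in which case tool definitions are not validated.
    assert _get_tool_definition_model("OpenAI") is OpenAIToolDefinition
    assert _get_tool_definition_model("anthropic") is AnthropicToolDefinition
    assert _get_tool_definition_model("unknown") is None

    # A well-formed Anthropic tool definition validates cleanly.
    AnthropicToolDefinition.model_validate(
        {
            "name": "get_weather",
            "input_schema": {"type": "object", "properties": {"location": {"type": "string"}}},
        }
    )

    # An invalid cache_control type is rejected, mirroring the
    # "with-invalid-anthropic-tools" mutation test cases.
    try:
        AnthropicToolDefinition.model_validate(
            {
                "name": "get_weather",
                "input_schema": {"type": "object", "properties": {}},
                "cache_control": {"type": "invalid_type"},
            }
        )
    except ValidationError as error:
        print(error)  # the reported error path includes "cache_control.type"

Centralizing the provider lookup in _get_tool_definition_model keeps
validate_tool_definitions_for_known_model_providers to a single loop, so adding
another provider only requires registering its tool definition model in one place.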