From ef5ce5677a3a1f97b2891684c8935619a1cd7e27 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Thu, 9 Jan 2025 11:18:42 -0800 Subject: [PATCH] feat(prompts): tool call definitions (#5922) --- schemas/openapi.json | 2 +- .../server/api/helpers/prompts/models.py | 206 +++- .../server/api/mutations/prompt_mutations.py | 97 +- tests/unit/server/api/helpers/test_models.py | 962 ++++++++++++++++++ .../api/mutations/test_prompt_mutations.py | 344 ++++++- 5 files changed, 1540 insertions(+), 71 deletions(-) create mode 100644 tests/unit/server/api/helpers/test_models.py diff --git a/schemas/openapi.json b/schemas/openapi.json index e485d01171d..58d620441f4 100644 --- a/schemas/openapi.json +++ b/schemas/openapi.json @@ -1918,6 +1918,7 @@ "title": "Definition" } }, + "additionalProperties": false, "type": "object", "required": [ "definition" @@ -1993,7 +1994,6 @@ "$ref": "#/components/schemas/PromptToolDefinition" }, "type": "array", - "minItems": 1, "title": "Tool Definitions" } }, diff --git a/src/phoenix/server/api/helpers/prompts/models.py b/src/phoenix/server/api/helpers/prompts/models.py index ebaaddac024..d287b7f077f 100644 --- a/src/phoenix/server/api/helpers/prompts/models.py +++ b/src/phoenix/server/api/helpers/prompts/models.py @@ -1,12 +1,28 @@ from enum import Enum -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator from typing_extensions import TypeAlias JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] +class Undefined: + """ + A singleton class that represents an unset or undefined value. Needed since Pydantic + can't natively distinguish between an undefined value and a value that is set to + None. + """ + + def __new__(cls) -> Any: + if not hasattr(cls, "_instance"): + cls._instance = super().__new__(cls) + return cls._instance + + +UNDEFINED: Any = Undefined() + + class PromptTemplateType(str, Enum): STRING = "STR" CHAT = "CHAT" @@ -54,7 +70,7 @@ class PromptStringTemplateV1(PromptModel): PromptTemplate: TypeAlias = Union[PromptChatTemplateV1, PromptStringTemplateV1] -class PromptJSONSchema(BaseModel): +class PromptJSONSchema(PromptModel): """A JSON schema definition used to guide an LLM's output""" definition: dict[str, Any] @@ -66,4 +82,186 @@ class PromptToolDefinition(PromptModel): class PromptToolsV1(PromptModel): version: Literal["tools-v1"] = "tools-v1" - tool_definitions: list[PromptToolDefinition] = Field(..., min_length=1) + tool_definitions: list[PromptToolDefinition] + + +class PromptVersion(PromptModel): + user_id: Optional[int] + description: Optional[str] + template_type: PromptTemplateType + template_format: PromptTemplateFormat + template: PromptTemplate + invocation_parameters: Optional[dict[str, Any]] + tools: PromptToolsV1 + output_schema: Optional[dict[str, Any]] + model_name: str + model_provider: str + + @model_validator(mode="after") + def validate_tool_definitions_for_known_model_providers(self) -> "PromptVersion": + tool_definitions = [tool_def.definition for tool_def in self.tools.tool_definitions] + tool_definition_model = _get_tool_definition_model(self.model_provider) + if tool_definition_model: + for tool_definition_index, tool_definition in enumerate(tool_definitions): + try: + tool_definition_model.model_validate(tool_definition) + except ValidationError as error: + raise ValueError( + f"Invalid tool definition at index {tool_definition_index}: {error}" + ) + return self + + +def _get_tool_definition_model( + model_provider: str, +) -> Optional[Union[type["OpenAIToolDefinition"], type["AnthropicToolDefinition"]]]: + if model_provider.lower() == "openai": + return OpenAIToolDefinition + if model_provider.lower() == "anthropic": + return AnthropicToolDefinition + return None + + +# JSON schema +JSONSchemaPrimitiveProperty: TypeAlias = Union[ + "JSONSchemaIntegerProperty", + "JSONSchemaNumberProperty", + "JSONSchemaBooleanProperty", + "JSONSchemaNullProperty", + "JSONSchemaStringProperty", +] +JSONSchemaContainerProperty: TypeAlias = Union[ + "JSONSchemaArrayProperty", + "JSONSchemaObjectProperty", +] +JSONSchemaProperty: TypeAlias = Union[ + "JSONSchemaPrimitiveProperty", + "JSONSchemaContainerProperty", +] + + +class JSONSchemaIntegerProperty(PromptModel): + type: Literal["integer"] + description: str = UNDEFINED + minimum: int = UNDEFINED + maximum: int = UNDEFINED + + @model_validator(mode="after") + def ensure_minimum_lte_maximum(self) -> "JSONSchemaIntegerProperty": + if ( + self.minimum is not UNDEFINED + and self.maximum is not UNDEFINED + and self.minimum > self.maximum + ): + raise ValueError("minimum must be less than or equal to maximum") + return self + + +class JSONSchemaNumberProperty(PromptModel): + type: Literal["number"] + description: str = UNDEFINED + minimum: float = UNDEFINED + maximum: float = UNDEFINED + + @model_validator(mode="after") + def ensure_minimum_lte_maximum(self) -> "JSONSchemaNumberProperty": + if ( + self.minimum is not UNDEFINED + and self.maximum is not UNDEFINED + and self.minimum > self.maximum + ): + raise ValueError("minimum must be less than or equal to maximum") + return self + + +class JSONSchemaBooleanProperty(PromptModel): + type: Literal["boolean"] + description: str = UNDEFINED + + +class JSONSchemaNullProperty(PromptModel): + type: Literal["null"] + description: str = UNDEFINED + + +class JSONSchemaStringProperty(PromptModel): + type: Literal["string"] + description: str = UNDEFINED + enum: list[str] = UNDEFINED + + @field_validator("enum") + def ensure_unique_enum_values(cls, enum_values: list[str]) -> list[str]: + if enum_values is UNDEFINED: + return enum_values + if len(enum_values) != len(set(enum_values)): + raise ValueError("Enum values must be unique") + return enum_values + + +class JSONSchemaArrayProperty(PromptModel): + type: Literal["array"] + description: str = UNDEFINED + items: Union[JSONSchemaProperty, "JSONSchemaAnyOf"] + + +class JSONSchemaObjectProperty(PromptModel): + type: Literal["object"] + description: str = UNDEFINED + properties: dict[str, Union[JSONSchemaProperty, "JSONSchemaAnyOf"]] + required: list[str] = UNDEFINED + additional_properties: bool = Field(UNDEFINED, alias="additionalProperties") + + @model_validator(mode="after") + def ensure_required_fields_are_included_in_properties(self) -> "JSONSchemaObjectProperty": + if self.required is UNDEFINED: + return self + invalid_fields = [field for field in self.required if field not in self.properties] + if invalid_fields: + raise ValueError(f"Required fields {invalid_fields} are not defined in properties") + return self + + +class JSONSchemaAnyOf(PromptModel): + description: str = UNDEFINED + any_of: list[JSONSchemaProperty] = Field(..., alias="anyOf") + + +# OpenAI tool definitions +class OpenAIFunctionDefinition(PromptModel): + """ + Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/shared_params/function_definition.py#L13 + """ + + name: str + description: str = UNDEFINED + parameters: JSONSchemaObjectProperty = UNDEFINED + strict: Optional[bool] = UNDEFINED + + +class OpenAIToolDefinition(PromptModel): + """ + Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/chat/chat_completion_tool_param.py#L12 + """ + + function: OpenAIFunctionDefinition + type: Literal["function"] + + +class AnthropicCacheControlEphemeralParam(PromptModel): + """ + Based on https://github.com/anthropics/anthropic-sdk-python/blob/93cbbbde964e244f02bf1bd2b579c5fabce4e267/src/anthropic/types/cache_control_ephemeral_param.py#L10 + """ + + type: Literal["ephemeral"] + + +# Anthropic tool definitions +class AnthropicToolDefinition(PromptModel): + """ + Based on https://github.com/anthropics/anthropic-sdk-python/blob/93cbbbde964e244f02bf1bd2b579c5fabce4e267/src/anthropic/types/tool_param.py#L22 + """ + + input_schema: JSONSchemaObjectProperty + name: str + cache_control: Optional[AnthropicCacheControlEphemeralParam] = UNDEFINED + description: str = UNDEFINED diff --git a/src/phoenix/server/api/mutations/prompt_mutations.py b/src/phoenix/server/api/mutations/prompt_mutations.py index 4c373f376e8..4f2037ae126 100644 --- a/src/phoenix/server/api/mutations/prompt_mutations.py +++ b/src/phoenix/server/api/mutations/prompt_mutations.py @@ -16,6 +16,7 @@ PromptChatTemplateV1, PromptJSONSchema, PromptToolsV1, + PromptVersion, ) from phoenix.server.api.input_types.PromptVersionInput import ChatPromptVersionInput from phoenix.server.api.queries import Query @@ -48,16 +49,18 @@ class PromptMutationMixin: async def create_chat_prompt( self, info: Info[Context, None], input: CreateChatPromptInput ) -> Prompt: + user_id: Optional[int] = None + assert isinstance(request := info.context.request, Request) + if "user" in request.scope: + assert isinstance(user := request.user, PhoenixUser) + user_id = int(user.identity) + try: tool_definitions = [] for tool in input.prompt_version.tools: pydantic_tool = tool.to_pydantic() - tool_definitions.append(pydantic_tool.dict()) - tools = ( - PromptToolsV1(tool_definitions=tool_definitions).dict() - if tool_definitions - else None - ) + tool_definitions.append(pydantic_tool) + tools = PromptToolsV1(tool_definitions=tool_definitions) output_schema = ( PromptJSONSchema.model_validate( strawberry.asdict(input.prompt_version.output_schema) @@ -68,26 +71,33 @@ async def create_chat_prompt( template = PromptChatTemplateV1.model_validate( strawberry.asdict(input.prompt_version.template) ).dict() - except ValidationError as error: - raise BadRequest(str(error)) - - user_id: Optional[int] = None - assert isinstance(request := info.context.request, Request) - if "user" in request.scope: - assert isinstance(user := request.user, PhoenixUser) - user_id = int(user.identity) - async with info.context.db() as session: - prompt_version = models.PromptVersion( - description=input.prompt_version.description, + pydantic_prompt_version = PromptVersion( user_id=user_id, + description=input.prompt_version.description, template_type=input.prompt_version.template_type.value, template_format=input.prompt_version.template_format.value, template=template, invocation_parameters=input.prompt_version.invocation_parameters, tools=tools, output_schema=output_schema, - model_provider=input.prompt_version.model_provider, model_name=input.prompt_version.model_name, + model_provider=input.prompt_version.model_provider, + ) + except ValidationError as error: + raise BadRequest(str(error)) + + async with info.context.db() as session: + prompt_version = models.PromptVersion( + description=pydantic_prompt_version.description, + user_id=pydantic_prompt_version.user_id, + template_type=pydantic_prompt_version.template_type, + template_format=pydantic_prompt_version.template_format, + template=pydantic_prompt_version.template.dict(), + invocation_parameters=pydantic_prompt_version.invocation_parameters, + tools=pydantic_prompt_version.tools.dict(), + output_schema=pydantic_prompt_version.output_schema, + model_provider=pydantic_prompt_version.model_provider, + model_name=pydantic_prompt_version.model_name, ) prompt = models.Prompt( name=input.name, @@ -107,16 +117,18 @@ async def create_chat_prompt_version( info: Info[Context, None], input: CreateChatPromptVersionInput, ) -> Prompt: + user_id: Optional[int] = None + assert isinstance(request := info.context.request, Request) + if "user" in request.scope: + assert isinstance(user := request.user, PhoenixUser) + user_id = int(user.identity) + try: tool_definitions = [] for tool in input.prompt_version.tools: pydantic_tool = tool.to_pydantic() - tool_definitions.append(pydantic_tool.dict()) - tools = ( - PromptToolsV1(tool_definitions=tool_definitions).dict() - if tool_definitions - else None - ) + tool_definitions.append(pydantic_tool) + tools = PromptToolsV1(tool_definitions=tool_definitions) output_schema = ( PromptJSONSchema.model_validate( strawberry.asdict(input.prompt_version.output_schema) @@ -127,14 +139,21 @@ async def create_chat_prompt_version( template = PromptChatTemplateV1.model_validate( strawberry.asdict(input.prompt_version.template) ).dict() + pydantic_prompt_version = PromptVersion( + user_id=user_id, + description=input.prompt_version.description, + template_type=input.prompt_version.template_type.value, + template_format=input.prompt_version.template_format.value, + template=template, + invocation_parameters=input.prompt_version.invocation_parameters, + tools=tools, + output_schema=output_schema, + model_name=input.prompt_version.model_name, + model_provider=input.prompt_version.model_provider, + ) except ValidationError as error: raise BadRequest(str(error)) - user_id: Optional[int] = None - assert isinstance(request := info.context.request, Request) - if "user" in request.scope: - assert isinstance(user := request.user, PhoenixUser) - user_id = int(user.identity) prompt_id = from_global_id_with_expected_type( global_id=input.prompt_id, expected_type_name=Prompt.__name__ ) @@ -145,16 +164,16 @@ async def create_chat_prompt_version( prompt_version = models.PromptVersion( prompt_id=prompt_id, - user_id=user_id, - description=input.prompt_version.description, - template_type=input.prompt_version.template_type.value, - template_format=input.prompt_version.template_format.value, - template=template, - invocation_parameters=input.prompt_version.invocation_parameters, - tools=tools, - output_schema=output_schema, - model_provider=input.prompt_version.model_provider, - model_name=input.prompt_version.model_name, + user_id=pydantic_prompt_version.user_id, + description=pydantic_prompt_version.description, + template_type=pydantic_prompt_version.template_type, + template_format=pydantic_prompt_version.template_format, + template=pydantic_prompt_version.template.dict(), + invocation_parameters=pydantic_prompt_version.invocation_parameters, + tools=pydantic_prompt_version.tools.dict(), + output_schema=pydantic_prompt_version.output_schema, + model_provider=pydantic_prompt_version.model_provider, + model_name=pydantic_prompt_version.model_name, ) session.add(prompt_version) diff --git a/tests/unit/server/api/helpers/test_models.py b/tests/unit/server/api/helpers/test_models.py new file mode 100644 index 00000000000..d253e84f987 --- /dev/null +++ b/tests/unit/server/api/helpers/test_models.py @@ -0,0 +1,962 @@ +from typing import Any + +import pytest +from pydantic import ValidationError + +from phoenix.server.api.helpers.prompts.models import AnthropicToolDefinition, OpenAIToolDefinition + + +@pytest.mark.parametrize( + "tool_definition", + [ + pytest.param( + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + }, + id="get-weather-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_delivery_date", + "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'", # noqa: E501 + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The customer's order ID.", + } + }, + "required": ["order_id"], + "additionalProperties": False, + }, + }, + }, + id="get-delivery-date-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "generate_recipe", + "description": "Generate a recipe based on the user's input", + "parameters": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the recipe.", + }, + "ingredients": { + "type": "array", + "items": {"type": "string"}, + "description": "List of ingredients required for the recipe.", + }, + "instructions": { + "type": "array", + "items": {"type": "string"}, + "description": "Step-by-step instructions for the recipe.", + }, + }, + "required": ["title", "ingredients", "instructions"], + "additionalProperties": False, + }, + }, + }, + id="generate-recipe-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_product_recommendations", + "description": "Searches for products matching certain criteria in the database", # noqa: E501 + "parameters": { + "type": "object", + "properties": { + "categories": { + "description": "categories that could be a match", + "type": "array", + "items": { + "type": "string", + "enum": [ + "coats & jackets", + "accessories", + "tops", + "jeans & trousers", + "skirts & dresses", + "shoes", + ], + }, + }, + "colors": { + "description": "colors that could be a match, empty array if N/A", + "type": "array", + "items": { + "type": "string", + "enum": [ + "black", + "white", + "brown", + "red", + "blue", + "green", + "orange", + "yellow", + "pink", + "gold", + "silver", + ], + }, + }, + "keywords": { + "description": "keywords that should be present in the item title or description", # noqa: E501 + "type": "array", + "items": {"type": "string"}, + }, + "price_range": { + "type": "object", + "properties": { + "min": {"type": "number"}, + "max": {"type": "number"}, + }, + "required": ["min", "max"], + "additionalProperties": False, + }, + "limit": { + "type": "integer", + "description": "The maximum number of products to return, use 5 by default if nothing is specified by the user", # noqa: E501 + }, + }, + "required": ["categories", "colors", "keywords", "price_range", "limit"], + "additionalProperties": False, + }, + }, + }, + id="get-product-recommendations-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_product_details", + "description": "Fetches more details about a product", + "parameters": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "The ID of the product to fetch details for", + } + }, + "required": ["product_id"], + "additionalProperties": False, + }, + }, + }, + id="get-product-details-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "add_to_cart", + "description": "Add items to cart when the user has confirmed their interest.", + "parameters": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "ID of the product to add to the cart", + }, + "quantity": { + "type": "integer", + "description": "Quantity of the product to add to the cart", # noqa: E501 + }, + }, + "required": ["product_id", "quantity"], + "additionalProperties": False, + }, + } + }, + "required": ["items"], + "additionalProperties": False, + }, + }, + }, + id="add-to-cart-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_order_details", + "description": "Fetches details about a specific order", + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The ID of the order to fetch details for", + } + }, + "required": ["order_id"], + "additionalProperties": False, + }, + }, + }, + id="get-order-details-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_user_orders", + "description": "Fetches the last orders for a given user", + "parameters": { + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "The ID of the user to fetch orders for", + }, + "limit": { + "type": "integer", + "description": "The maximum number of orders to return, use 5 by default and increase the number if the relevant order is not found.", # noqa: E501 + }, + }, + "required": ["user_id", "limit"], + "additionalProperties": False, + }, + }, + }, + id="get-user-orders-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "search_faq", + "description": "Searches the FAQ for an answer to the user's question", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The question to search the FAQ for", + } + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + }, + id="search-faq-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "process_return", + "description": "Processes a return and creates a return label", + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The ID of the order to process a return for", + }, + "items": { + "type": "array", + "description": "The items to return", + "items": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "The ID of the product to return", + }, + "quantity": { + "type": "integer", + "description": "The quantity of the product to return", + }, + }, + "required": ["product_id", "quantity"], + "additionalProperties": False, + }, + }, + }, + "required": ["order_id", "items"], + "additionalProperties": False, + }, + }, + }, + id="process-return-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_return_status", + "description": "Finds the status of a return", + "parameters": { + "type": "object", + "properties": { + "order_id": { + "type": "string", + "description": "The ID of the order to fetch the return status for", + } + }, + "required": ["order_id"], + "additionalProperties": False, + }, + }, + }, + id="get-return-status-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_recommendations", + "description": "Fetches recommendations based on the user's preferences", + "parameters": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type of place to search recommendations for", + "enum": ["restaurant", "hotel"], + }, + "keywords": { + "type": "array", + "description": "Keywords that should be present in the recommendations", # noqa: E501 + "items": {"type": "string"}, + }, + "location": { + "type": "string", + "description": "The location to search recommendations for", + }, + }, + "required": ["type", "keywords", "location"], + "additionalProperties": False, + }, + }, + }, + id="get-recommendations-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "show_on_map", + "description": "Places pins on the map for relevant locations", + "parameters": { + "type": "object", + "properties": { + "pins": { + "type": "array", + "description": "The pins to place on the map", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the place", + }, + "coordinates": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + }, + "required": ["name", "coordinates"], + "additionalProperties": False, + }, + }, + }, + "required": ["pins"], + "additionalProperties": False, + }, + }, + }, + id="show-on-map-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "fetch_availability", + "description": "Fetches the availability for a given place", + "parameters": { + "type": "object", + "properties": { + "place_id": { + "type": "string", + "description": "The ID of the place to fetch availability for", + } + }, + }, + }, + }, + id="fetch-availability-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "create_booking", + "description": "Creates a booking on the user's behalf", + "parameters": { + "type": "object", + "properties": { + "place_id": { + "type": "string", + "description": "The ID of the place to create a booking for", + }, + "booking_details": { + "anyOf": [ + { + "type": "object", + "description": "Restaurant booking with specific date and time", # noqa: E501 + "properties": { + "date": { + "type": "string", + "description": "The date of the booking, in format YYYY-MM-DD", # noqa: E501 + }, + "time": { + "type": "string", + "description": "The time of the booking, in format HH:MM", # noqa: E501 + }, + }, + "required": ["date", "time"], + }, + { + "type": "object", + "description": "Hotel booking with specific check-in and check-out dates", # noqa: E501 + "properties": { + "check_in": { + "type": "string", + "description": "The check-in date of the booking, in format YYYY-MM-DD", # noqa: E501 + }, + "check_out": { + "type": "string", + "description": "The check-out date of the booking, in format YYYY-MM-DD", # noqa: E501 + }, + }, + "required": ["check_in", "check_out"], + }, + ], + }, + }, + "required": ["place_id", "booking_details"], + "additionalProperties": False, + }, + }, + }, + id="create-booking-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "pick_tshirt_size", + "description": "Call this if the user specifies which size t-shirt they want", + "parameters": { + "type": "object", + "properties": { + "size": { + "type": "string", + "enum": ["s", "m", "l"], + "description": "The size of the t-shirt that the user would like to order", # noqa: E501 + } + }, + "required": ["size"], + "additionalProperties": False, + }, + }, + }, + id="pick-tshirt-size-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "test_primitives", + "description": "Test all primitive types", + "parameters": { + "type": "object", + "properties": { + "string_field": {"type": "string", "description": "A string field"}, + "number_field": {"type": "number", "description": "A number field"}, + "integer_field": {"type": "integer", "description": "An integer field"}, + "boolean_field": {"type": "boolean", "description": "A boolean field"}, + "null_field": {"type": "null", "description": "A null field"}, + }, + "required": [ + "string_field", + "number_field", + "integer_field", + "boolean_field", + "null_field", + ], + "additionalProperties": False, + }, + }, + }, + id="primitive-types-function", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "update_user_profile", + "description": "Updates a user's profile information", + "parameters": { + "type": "object", + "properties": { + "user_id": { + "type": "string", + "description": "The ID of the user to update", + }, + "nickname": { + "description": "Optional nickname that can be null or a string", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + }, + "required": ["user_id"], + "additionalProperties": False, + }, + }, + }, + id="optional-anyof-parameter", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "categorize_colors", + "description": "Categorize colors into warm, cool, or neutral tones, with null for uncertain cases", # noqa: E501 + "parameters": { + "type": "object", + "properties": { + "colors": { + "type": "array", + "description": "List of color categories, with null for uncertain colors", # noqa: E501 + "items": { + "anyOf": [ + { + "type": "string", + "enum": ["warm", "cool", "neutral"], + "description": "Color category", + }, + {"type": "null"}, + ] + }, + } + }, + "required": ["colors"], + "additionalProperties": False, + }, + }, + }, + id="array-of-optional-enums", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature within valid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "integer", + "minimum": 0, + "maximum": 100, + "description": "Temperature in Fahrenheit (0-100)", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="integer-min-max-constraints", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature within valid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "number", + "minimum": 0.5, # float min + "maximum": 100, # integer max + "description": "Temperature in Fahrenheit (0-100)", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="number-min-max-constraints", + ), + ], +) +def test_openai_tool_definition_passes_valid_tool_schemas(tool_definition: dict[str, Any]) -> None: + OpenAIToolDefinition.model_validate(tool_definition) + + +@pytest.mark.parametrize( + "tool_definition", + [ + pytest.param( + { + "type": "function", + "function": { + "name": "pick_tshirt_size", + "description": "Call this if the user specifies which size t-shirt they want", + "parameters": { + "type": "object", + "properties": { + "size": { + "type": "invalid_type", + "enum": ["s", "m", "l"], + "description": "The size of the t-shirt that the user would like to order", # noqa: E501 + } + }, + "required": ["size"], + "additionalProperties": False, + }, + }, + }, + id="invalid-data-type", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Sets the temperature for the thermostat", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "number", + "enum": ["70", "72", "74"], # only string properties can have enums + "description": "The temperature to set in Fahrenheit", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="number-property-with-invalid-enum", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "extra": "extra", # extra properties are not allowed + }, + }, + }, + id="extra-properties", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "update_user", + "description": "Updates user information", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": [ + "name", + "email", # email is not in properties + ], + "additionalProperties": False, + }, + }, + }, + id="required-field-not-in-properties", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_preferences", + "parameters": { + "type": "object", + "properties": { + "priority": { + "type": "string", + "enum": [ + 0, # integer enum values not allowed + "low", + "medium", + "high", + ], + "description": "The priority level to set", + } + }, + "required": ["priority"], + "additionalProperties": False, + }, + }, + }, + id="string-property-with-priority-enum", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "select_color", + "description": "Select a color from the available options", + "parameters": { + "type": "object", + "properties": { + "color": { + "type": "string", + "enum": [ + "red", + "blue", + "red", # duplicate enum value + ], + "description": "The color to select", + } + }, + "required": ["color"], + "additionalProperties": False, + }, + }, + }, + id="duplicate-enum-values", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature with invalid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "integer", + "minimum": 100, + "maximum": 0, # min > max + "description": "Temperature in Celsius", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="integer-min-max-range", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_count", + "description": "Set an integer count with float bounds", + "parameters": { + "type": "object", + "properties": { + "count": { + "type": "integer", + "minimum": 1.4, # float not allowed for integer property + "description": "Count value", + } + }, + "required": ["count"], + "additionalProperties": False, + }, + }, + }, + id="integer-float-bounds", + ), + pytest.param( + { + "type": "function", + "function": { + "name": "set_temperature", + "description": "Set temperature with invalid range", + "parameters": { + "type": "object", + "properties": { + "temp": { + "type": "number", + "minimum": 100, + "maximum": 0, # min > max + "description": "Temperature in Celsius", + } + }, + "required": ["temp"], + "additionalProperties": False, + }, + }, + }, + id="number-min-max-range", + ), + ], +) +def test_openai_tool_definition_fails_invalid_tool_schemas(tool_definition: dict[str, Any]) -> None: + with pytest.raises(ValidationError): + OpenAIToolDefinition.model_validate(tool_definition) + + +@pytest.mark.parametrize( + "tool_definition", + [ + pytest.param( + { + "name": "get_weather", + "description": "Get the current weather in a given location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + }, + id="get-weather-function", + ), + pytest.param( + { + "name": "get_time", + "description": "Get the current time in a given time zone", + "input_schema": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The IANA time zone name, e.g. America/Los_Angeles", + } + }, + "required": ["timezone"], + }, + }, + id="get-time-function", + ), + pytest.param( + { + "name": "get_location", + "description": "Get the current user location based on their IP address. This tool has no parameters or arguments.", # noqa: E501 + "input_schema": {"type": "object", "properties": {}}, + }, + id="get-location-function", + ), + pytest.param( + { + "name": "record_summary", + "description": "Record summary of an image using well-structured JSON.", + "input_schema": { + "type": "object", + "properties": { + "key_colors": { + "type": "array", + "items": { + "type": "object", + "properties": { + "r": {"type": "number", "description": "red value [0.0, 1.0]"}, + "g": { + "type": "number", + "description": "green value [0.0, 1.0]", + }, + "b": {"type": "number", "description": "blue value [0.0, 1.0]"}, + "name": { + "type": "string", + "description": 'Human-readable color name in snake_case, e.g. "olive_green" or "turquoise"', # noqa: E501 + }, + }, + "required": ["r", "g", "b", "name"], + }, + "description": "Key colors in the image. Limit to less then four.", + }, + "description": { + "type": "string", + "description": "Image description. One to two sentences max.", + }, + "estimated_year": { + "type": "integer", + "description": "Estimated year that the images was taken, if is it a photo. Only set this if the image appears to be non-fictional. Rough estimates are okay!", # noqa: E501 + }, + }, + "required": ["key_colors", "description"], + }, + }, + id="record-image-summary", + ), + pytest.param( + { + "name": "get_weather", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + "cache_control": {"type": "ephemeral"}, + }, + id="get-weather-function-cache-control-ephemeral", + ), + pytest.param( + { + "name": "get_weather", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + "cache_control": None, + }, + id="get-weather-function-cache-control-none", + ), + ], +) +def test_anthropic_tool_definition_passes_valid_tool_schemas( + tool_definition: dict[str, Any], +) -> None: + AnthropicToolDefinition.model_validate(tool_definition) diff --git a/tests/unit/server/api/mutations/test_prompt_mutations.py b/tests/unit/server/api/mutations/test_prompt_mutations.py index eb30d7274a1..c0b801dc380 100644 --- a/tests/unit/server/api/mutations/test_prompt_mutations.py +++ b/tests/unit/server/api/mutations/test_prompt_mutations.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import pytest from strawberry.relay.types import GlobalID @@ -91,7 +91,7 @@ class TestPromptMutations: """ @pytest.mark.parametrize( - "variables,expected_tools,expected_output_schema", + "variables", [ pytest.param( { @@ -109,8 +109,6 @@ class TestPromptMutations: }, } }, - [], - None, id="basic-input", ), pytest.param( @@ -124,16 +122,87 @@ class TestPromptMutations: "templateFormat": "MUSTACHE", "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, - "modelProvider": "openai", - "modelName": "o1-mini", + "modelProvider": "unknown", + "modelName": "unknown", "tools": [{"definition": {"foo": "bar"}}], }, } }, - [{"definition": {"foo": "bar"}}], - None, id="with-tools", ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + id="with-valid-openai-tools", + ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + } + } + ], + }, + } + }, + id="with-valid-anthropic-tools", + ), pytest.param( { "input": { @@ -151,8 +220,6 @@ class TestPromptMutations: }, } }, - [], - {"definition": {"foo": "bar"}}, id="with-output-schema", ), ], @@ -162,8 +229,6 @@ async def test_create_chat_prompt_succeeds_with_valid_input( db: DbSessionFactory, gql_client: AsyncGraphQLClient, variables: dict[str, Any], - expected_tools: list[dict[str, Any]], - expected_output_schema: Optional[dict[str, Any]], ) -> None: result = await gql_client.execute(self.CREATE_CHAT_PROMPT_MUTATION, variables) assert not result.errors @@ -180,10 +245,14 @@ async def test_create_chat_prompt_succeeds_with_valid_input( assert prompt_version.pop("user") is None assert prompt_version.pop("templateType") == "CHAT" assert prompt_version.pop("templateFormat") == "MUSTACHE" - assert prompt_version.pop("modelProvider") == "openai" - assert prompt_version.pop("modelName") == "o1-mini" + expected_model_provider = variables["input"]["promptVersion"]["modelProvider"] + expected_model_name = variables["input"]["promptVersion"]["modelName"] + assert prompt_version.pop("modelProvider") == expected_model_provider + assert prompt_version.pop("modelName") == expected_model_name assert prompt_version.pop("invocationParameters") == {"temperature": 0.4} + expected_tools = variables["input"]["promptVersion"].get("tools", []) assert prompt_version.pop("tools") == expected_tools + expected_output_schema = variables["input"]["promptVersion"].get("outputSchema") assert prompt_version.pop("outputSchema") == expected_output_schema assert isinstance(prompt_version.pop("id"), str) @@ -272,6 +341,84 @@ async def test_create_chat_prompt_fails_on_name_conflict( "Input should be a valid dictionary", id="invalid-output-schema", ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "invalid_type", # invalid schema type + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + "function.parameters.type", + id="with-invalid-openai-tools", + ), + pytest.param( + { + "input": { + "name": "prompt-name", + "description": "prompt-description", + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + "cache_control": { + "type": "invalid_type" + }, # invalid cache control type + } + } + ], + }, + } + }, + "cache_control.type", + id="with-invalid-anthropic-tools", + ), ], ) async def test_create_chat_prompt_fails_with_invalid_input( @@ -283,7 +430,7 @@ async def test_create_chat_prompt_fails_with_invalid_input( assert result.data is None @pytest.mark.parametrize( - "variables,expected_tools,expected_output_schema", + "variables", [ pytest.param( { @@ -300,8 +447,6 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [], - None, id="basic-input", ), pytest.param( @@ -314,16 +459,45 @@ async def test_create_chat_prompt_fails_with_invalid_input( "templateFormat": "MUSTACHE", "template": {"messages": [{"role": "USER", "content": "hello world"}]}, "invocationParameters": {"temperature": 0.4}, - "modelProvider": "openai", - "modelName": "o1-mini", + "modelProvider": "unknown", + "modelName": "unknown", "tools": [{"definition": {"foo": "bar"}}], }, } }, - [{"definition": {"foo": "bar"}}], - None, id="with-tools", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + id="with-valid-openai-tools", + ), pytest.param( { "input": { @@ -340,10 +514,48 @@ async def test_create_chat_prompt_fails_with_invalid_input( }, } }, - [], - {"definition": {"foo": "bar"}}, id="with-output-schema", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + } + } + ], + }, + } + }, + id="with-valid-anthropic-tools", + ), ], ) async def test_create_chat_prompt_version_succeeds_with_valid_input( @@ -351,8 +563,6 @@ async def test_create_chat_prompt_version_succeeds_with_valid_input( db: DbSessionFactory, gql_client: AsyncGraphQLClient, variables: dict[str, Any], - expected_tools: list[dict[str, Any]], - expected_output_schema: Optional[dict[str, Any]], ) -> None: # Create initial prompt create_prompt_result = await gql_client.execute( @@ -391,10 +601,14 @@ async def test_create_chat_prompt_version_succeeds_with_valid_input( assert latest_prompt_version.pop("user") is None assert latest_prompt_version.pop("templateType") == "CHAT" assert latest_prompt_version.pop("templateFormat") == "MUSTACHE" - assert latest_prompt_version.pop("modelProvider") == "openai" - assert latest_prompt_version.pop("modelName") == "o1-mini" + expected_model_provider = variables["input"]["promptVersion"]["modelProvider"] + expected_model_name = variables["input"]["promptVersion"]["modelName"] + assert latest_prompt_version.pop("modelProvider") == expected_model_provider + assert latest_prompt_version.pop("modelName") == expected_model_name assert latest_prompt_version.pop("invocationParameters") == {"temperature": 0.4} + expected_tools = variables["input"]["promptVersion"].get("tools", []) assert latest_prompt_version.pop("tools") == expected_tools + expected_output_schema = variables["input"]["promptVersion"].get("outputSchema") assert latest_prompt_version.pop("outputSchema") == expected_output_schema assert isinstance(latest_prompt_version.pop("id"), str) @@ -476,6 +690,82 @@ async def test_create_chat_prompt_version_fails_with_nonexistent_prompt_id( "Input should be a valid dictionary", id="invalid-output-schema", ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "openai", + "modelName": "o1-mini", + "tools": [ + { + "definition": { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "invalid_type", # invalid schema type + "properties": {"location": {"type": "string"}}, + }, + }, + } + } + ], + }, + } + }, + "function.parameters.type", + id="with-invalid-openai-tools", + ), + pytest.param( + { + "input": { + "promptId": str(GlobalID("Prompt", "1")), + "promptVersion": { + "description": "prompt-version-description", + "templateType": "CHAT", + "templateFormat": "MUSTACHE", + "template": {"messages": [{"role": "USER", "content": "hello world"}]}, + "invocationParameters": {"temperature": 0.4}, + "modelProvider": "anthropic", + "modelName": "claude-2", + "tools": [ + { + "definition": { + "name": "get_weather", + "description": "Get the current weather in a given location", # noqa: E501 + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", # noqa: E501 + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": 'The unit of temperature, either "celsius" or "fahrenheit"', # noqa: E501 + }, + }, + "required": ["location"], + }, + "cache_control": { + "type": "invalid_type" + }, # invalid cache control type + } + } + ], + }, + } + }, + "cache_control.type", + id="with-invalid-anthropic-tools", + ), ], ) async def test_create_chat_prompt_version_fails_with_invalid_input(