Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prompts): tool call definitions #5922

Merged
merged 15 commits into from
Jan 9, 2025
2 changes: 1 addition & 1 deletion schemas/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,7 @@
"title": "Definition"
}
},
"additionalProperties": false,
"type": "object",
"required": [
"definition"
Expand Down Expand Up @@ -1993,7 +1994,6 @@
"$ref": "#/components/schemas/PromptToolDefinition"
},
"type": "array",
"minItems": 1,
"title": "Tool Definitions"
}
},
Expand Down
206 changes: 202 additions & 4 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
from enum import Enum
from typing import Any, Literal, Union
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
from typing_extensions import TypeAlias

JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]


class Undefined:
"""
A singleton class that represents an unset or undefined value. Needed since Pydantic
can't natively distinguish between an undefined value and a value that is set to
None.
"""

def __new__(cls) -> Any:
if not hasattr(cls, "_instance"):
cls._instance = super().__new__(cls)
return cls._instance


UNDEFINED: Any = Undefined()


class PromptTemplateType(str, Enum):
STRING = "STR"
CHAT = "CHAT"
Expand Down Expand Up @@ -54,7 +70,7 @@ class PromptStringTemplateV1(PromptModel):
PromptTemplate: TypeAlias = Union[PromptChatTemplateV1, PromptStringTemplateV1]


class PromptJSONSchema(BaseModel):
class PromptJSONSchema(PromptModel):
"""A JSON schema definition used to guide an LLM's output"""

definition: dict[str, Any]
Expand All @@ -66,4 +82,186 @@ class PromptToolDefinition(PromptModel):

class PromptToolsV1(PromptModel):
version: Literal["tools-v1"] = "tools-v1"
tool_definitions: list[PromptToolDefinition] = Field(..., min_length=1)
tool_definitions: list[PromptToolDefinition]


class PromptVersion(PromptModel):
user_id: Optional[int]
description: Optional[str]
template_type: PromptTemplateType
template_format: PromptTemplateFormat
template: PromptTemplate
invocation_parameters: Optional[dict[str, Any]]
tools: PromptToolsV1
output_schema: Optional[dict[str, Any]]
model_name: str
model_provider: str

@model_validator(mode="after")
def validate_tool_definitions_for_known_model_providers(self) -> "PromptVersion":
tool_definitions = [tool_def.definition for tool_def in self.tools.tool_definitions]
tool_definition_model = _get_tool_definition_model(self.model_provider)
if tool_definition_model:
for tool_definition_index, tool_definition in enumerate(tool_definitions):
try:
tool_definition_model.model_validate(tool_definition)
except ValidationError as error:
raise ValueError(
f"Invalid tool definition at index {tool_definition_index}: {error}"
)
return self


def _get_tool_definition_model(
model_provider: str,
) -> Optional[Union[type["OpenAIToolDefinition"], type["AnthropicToolDefinition"]]]:
if model_provider.lower() == "openai":
return OpenAIToolDefinition
if model_provider.lower() == "anthropic":
return AnthropicToolDefinition
return None


# JSON schema
JSONSchemaPrimitiveProperty: TypeAlias = Union[
"JSONSchemaIntegerProperty",
"JSONSchemaNumberProperty",
"JSONSchemaBooleanProperty",
"JSONSchemaNullProperty",
"JSONSchemaStringProperty",
]
JSONSchemaContainerProperty: TypeAlias = Union[
"JSONSchemaArrayProperty",
"JSONSchemaObjectProperty",
]
JSONSchemaProperty: TypeAlias = Union[
"JSONSchemaPrimitiveProperty",
"JSONSchemaContainerProperty",
]


class JSONSchemaIntegerProperty(PromptModel):
type: Literal["integer"]
description: str = UNDEFINED
minimum: int = UNDEFINED
maximum: int = UNDEFINED

@model_validator(mode="after")
def ensure_minimum_lte_maximum(self) -> "JSONSchemaIntegerProperty":
if (
self.minimum is not UNDEFINED
and self.maximum is not UNDEFINED
and self.minimum > self.maximum
):
raise ValueError("minimum must be less than or equal to maximum")
return self


class JSONSchemaNumberProperty(PromptModel):
type: Literal["number"]
description: str = UNDEFINED
minimum: float = UNDEFINED
maximum: float = UNDEFINED

@model_validator(mode="after")
def ensure_minimum_lte_maximum(self) -> "JSONSchemaNumberProperty":
if (
self.minimum is not UNDEFINED
and self.maximum is not UNDEFINED
and self.minimum > self.maximum
):
raise ValueError("minimum must be less than or equal to maximum")
return self


class JSONSchemaBooleanProperty(PromptModel):
type: Literal["boolean"]
description: str = UNDEFINED


class JSONSchemaNullProperty(PromptModel):
type: Literal["null"]
description: str = UNDEFINED


class JSONSchemaStringProperty(PromptModel):
type: Literal["string"]
description: str = UNDEFINED
enum: list[str] = UNDEFINED

@field_validator("enum")
def ensure_unique_enum_values(cls, enum_values: list[str]) -> list[str]:
if enum_values is UNDEFINED:
return enum_values
if len(enum_values) != len(set(enum_values)):
raise ValueError("Enum values must be unique")
return enum_values


class JSONSchemaArrayProperty(PromptModel):
type: Literal["array"]
description: str = UNDEFINED
items: Union[JSONSchemaProperty, "JSONSchemaAnyOf"]


class JSONSchemaObjectProperty(PromptModel):
type: Literal["object"]
description: str = UNDEFINED
properties: dict[str, Union[JSONSchemaProperty, "JSONSchemaAnyOf"]]
required: list[str] = UNDEFINED
additional_properties: bool = Field(UNDEFINED, alias="additionalProperties")

@model_validator(mode="after")
def ensure_required_fields_are_included_in_properties(self) -> "JSONSchemaObjectProperty":
if self.required is UNDEFINED:
return self
invalid_fields = [field for field in self.required if field not in self.properties]
if invalid_fields:
raise ValueError(f"Required fields {invalid_fields} are not defined in properties")
return self


class JSONSchemaAnyOf(PromptModel):
description: str = UNDEFINED
any_of: list[JSONSchemaProperty] = Field(..., alias="anyOf")


# OpenAI tool definitions
class OpenAIFunctionDefinition(PromptModel):
"""
Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/shared_params/function_definition.py#L13
"""

name: str
description: str = UNDEFINED
parameters: JSONSchemaObjectProperty = UNDEFINED
strict: Optional[bool] = UNDEFINED


class OpenAIToolDefinition(PromptModel):
"""
Based on https://github.com/openai/openai-python/blob/1e07c9d839e7e96f02d0a4b745f379a43086334c/src/openai/types/chat/chat_completion_tool_param.py#L12
"""

function: OpenAIFunctionDefinition
type: Literal["function"]


class AnthropicCacheControlEphemeralParam(PromptModel):
"""
Based on https://github.com/anthropics/anthropic-sdk-python/blob/93cbbbde964e244f02bf1bd2b579c5fabce4e267/src/anthropic/types/cache_control_ephemeral_param.py#L10
"""

type: Literal["ephemeral"]


# Anthropic tool definitions
class AnthropicToolDefinition(PromptModel):
"""
Based on https://github.com/anthropics/anthropic-sdk-python/blob/93cbbbde964e244f02bf1bd2b579c5fabce4e267/src/anthropic/types/tool_param.py#L22
"""

input_schema: JSONSchemaObjectProperty
name: str
cache_control: Optional[AnthropicCacheControlEphemeralParam] = UNDEFINED
description: str = UNDEFINED
Loading
Loading