diff --git a/tests/test_schema.py b/tests/test_schema.py index 19744e5..31ede1c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -8,14 +8,17 @@ from pydantic import BaseModel from wanga.schema.extractor import default_schema_extractor +from wanga.schema.jsonschema import JsonSchemaFlavour from wanga.schema.normalize import normalize_annotation, unpack_optional from wanga.schema.schema import ( CallableSchema, + LiteralNode, ObjectField, ObjectNode, PrimitiveNode, SequenceNode, UndefinedNode, + UnionNode, ) @@ -28,6 +31,14 @@ def test_normalize_schema(): typing.Union[int, float]: int | float, typing.Optional[int]: int | None, typing.List: list, + typing.Union[typing.Union[int, float], str]: int | float | str, + (typing.Literal[1] | typing.Literal[2] | typing.Literal[3]): typing.Literal[ + 1, 2, 3 + ], + ( + typing.Literal[1, 2] + | typing.Union[typing.Literal[2, 3], typing.Literal[3, 4]] + ): (typing.Literal[1, 2, 3, 4]), } for annotation, result in expected.items(): assert normalize_annotation(annotation) == result @@ -55,7 +66,7 @@ def test_concretize_schema(): def test_extract_schema(): - def foo(x: int, y: str = "hello"): # noqa + def foo(x: int, y: str = "hello", z: tuple[int, ...] = ()): # noqa pass foo_schema = CallableSchema( @@ -77,19 +88,30 @@ def foo(x: int, y: str = "hello"): # noqa required=False, hint=None, ), + ObjectField( + name="z", + schema=SequenceNode( + sequence_type=tuple, + item_schema=PrimitiveNode(primitive_type=int), + ), + required=False, + hint=None, + ), ], ), + long_description=None, ) assert default_schema_extractor.extract_schema(foo) == foo_schema - def bar(x: typing.List[int]) -> int: # noqa + def bar(x: typing.List[int], y: typing.Literal["hehe"] | float) -> int: # noqa r"""Bar. Blah blah blah. Args: x: The x. + y: Hard example. """ return 0 @@ -109,8 +131,20 @@ def bar(x: typing.List[int]) -> int: # noqa required=True, hint="The x.", ), + ObjectField( + name="y", + schema=UnionNode( + [ + PrimitiveNode(primitive_type=float), + LiteralNode(options=["hehe"]), + ] + ), + required=True, + hint="Hard example.", + ), ], ), + long_description="Blah blah blah.", ) assert default_schema_extractor.extract_schema(bar) == bar_schema @@ -147,6 +181,7 @@ def __init__(self, x: int, y): ), ], ), + long_description=None, ) assert default_schema_extractor.extract_schema(Baz) == baz_schema @@ -186,6 +221,7 @@ class Qux: ), ], ), + long_description="I have attributes instead of arguments!", ) assert default_schema_extractor.extract_schema(Qux) == qux_schema @@ -258,6 +294,7 @@ class Goo: ), ], ), + long_description="I am a dataclass, and I use the stupid ReST docstring syntax!", ) assert default_schema_extractor.extract_schema(Goo) == goo_schema @@ -266,7 +303,7 @@ class Hoo(BaseModel): r"""I am Hoo. I am a Pydantic model! - And I use Numpy Doc format!. + And I use Numpy Doc format! Parameters ---------- @@ -309,6 +346,75 @@ class Hoo(BaseModel): ), ], ), + long_description="I am a Pydantic model!\nAnd I use Numpy Doc format!", ) assert default_schema_extractor.extract_schema(Hoo) == hoo_schema + + +def test_json_schema(): + @frozen + class Inner: + """Inner. + + Long description of Inner. + + Attributes: + x: The x. + """ + + x: int + + def foo( + a: int, + b: str, + c: Inner, + d: tuple[int, ...] = (), + e: typing.Literal["x", "y"] = "x", + f: str | int = 1, + ): + r"""Foo! + + Long description of foo. + + Args: + a: The a. + b: The b. + c: The c. 
+            """
+
+    expected_json_schema = {
+        "name": "foo",
+        "description": "Foo!\n\nLong description of foo.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "a": {"type": "integer", "description": "The a."},
+                "b": {"type": "string", "description": "The b."},
+                "c": {
+                    "type": "object",
+                    "properties": {"x": {"type": "integer", "description": "The x."}},
+                    "required": ["x"],
+                    "description": "The c.\n\nInner.",
+                },
+                "d": {
+                    "type": "array",
+                    "items": {"type": "integer"},
+                },
+                "e": {
+                    "type": "string",
+                    "enum": ["x", "y"],
+                },
+                "f": {
+                    "type": ["integer", "string"],
+                },
+            },
+            "required": ["a", "b", "c"],
+        },
+    }
+
+    core_schema = default_schema_extractor.extract_schema(foo)
+    json_schema = core_schema.json_schema(
+        JsonSchemaFlavour.OPENAI, include_long_description=True
+    )
+    assert json_schema == expected_json_schema
diff --git a/wanga/schema/extractor.py b/wanga/schema/extractor.py
index 69269be..bfade6a 100644
--- a/wanga/schema/extractor.py
+++ b/wanga/schema/extractor.py
@@ -1,7 +1,7 @@
 import inspect
 from collections.abc import Callable
 from types import NoneType, UnionType
-from typing import Any, get_args, get_origin
+from typing import Any, Literal, Union, get_args, get_origin
 
 from attrs import define, field, frozen
 from docstring_parser import parse as parse_docstring
@@ -10,6 +10,7 @@
 from .normalize import normalize_annotation
 from .schema import (
     CallableSchema,
+    LiteralNode,
     MappingNode,
     ObjectField,
     ObjectNode,
@@ -30,6 +31,7 @@
 @frozen
 class DocstringHints:
     object_hint: str | None
+    long_description: str | None
     param_to_hint: dict[str, str]
 
 
@@ -45,7 +47,7 @@ class SchemaExtractor:
     If you want to add new extraction functions, use the `register_extract_fn` method.
 
     Attributes:
-        exctractor_functions: List of functions that take type annotation as an input
+        extractor_functions: List of functions that take type annotation as an input
             and try to produce the `CallableSchema`.
     """
 
@@ -66,6 +68,7 @@ def extract_hints(self, callable: Callable) -> DocstringHints:
         if docstring is not None:
             docstring = parse_docstring(docstring)
             object_hint = docstring.short_description
+            long_description = docstring.long_description
             param_to_hint = {
                 param.arg_name: param.description
                 for param in docstring.params
@@ -73,14 +76,20 @@
             }
         else:
             object_hint = None
+            long_description = None
             param_to_hint = {}
 
         if isinstance(callable, type) and hasattr(callable, "__init__"):
             init_hints = self.extract_hints(callable.__init__)
             object_hint = object_hint or init_hints.object_hint
+            long_description = long_description or init_hints.long_description
             param_to_hint.update(init_hints.param_to_hint)
 
-        return DocstringHints(object_hint, param_to_hint)
+        return DocstringHints(
+            object_hint,
+            long_description,
+            param_to_hint,
+        )
 
     def annotation_to_schema(self, annotation) -> SchemaNode:
         if annotation in [Any, None]:
@@ -99,6 +108,22 @@
             init_schema = self.extract_schema(annotation)
             return init_schema.call_schema
 
+        if origin is Literal:
+            return LiteralNode(options=list(args))
+
+        if origin in [Union, UnionType]:
+            # The second check may look redundant, since the normalization step rewrites
+            # unions with the `|` operator. Unfortunately, `Literal[1] | str` evaluates to
+            # `typing.Union` rather than `types.UnionType`, so we have to check against
+            # both Union and UnionType.
+            arg_schemas = []
+            for arg in args:
+                if arg is NoneType:
+                    arg_schemas.append(None)
+                else:
+                    arg_schemas.append(self.annotation_to_schema(arg))
+            return UnionNode(options=arg_schemas)
+
         # Normalization step has already converted all abstract classes to the corresponding concrete types.
         # So we can safely check against list and dict.
         if issubclass(origin, list):
@@ -110,6 +135,11 @@
                 item_schema=self.annotation_to_schema(args[0]),
             )
         if issubclass(origin, tuple):
+            if len(args) == 2 and args[1] is Ellipsis:
+                return SequenceNode(
+                    sequence_type=origin,
+                    item_schema=self.annotation_to_schema(args[0]),
+                )
             return TupleNode(
                 tuple_type=origin,
                 item_schemas=[self.annotation_to_schema(arg) for arg in args],
@@ -123,14 +153,6 @@
                 key_schema=self.annotation_to_schema(args[0]),
                 value_schema=self.annotation_to_schema(args[1]),
             )
-        if issubclass(origin, UnionType):
-            arg_schemas = []
-            for arg in args:
-                if arg is NoneType:
-                    arg_schemas.append(None)
-                else:
-                    arg_schemas.append(self.annotation_to_schema(arg))
-            return UnionNode(options=arg_schemas)
 
         raise ValueError(f"Unsupported type annotation: {annotation}")
 
@@ -139,7 +161,9 @@
         try:
             return self._extract_schema_impl(callable)
         except Exception as e:
-            raise SchemaExtractionError(f"Failed to extract schema for {callable}: {e}")
+            raise SchemaExtractionError(
+                f"Failed to extract schema for {callable}"
+            ) from e
 
     def _extract_schema_impl(self, callable: Callable) -> CallableSchema:
         for fn in self.exctractor_fns:
@@ -193,6 +217,7 @@
                 hint=hints.object_hint,
             ),
             return_schema=return_schema,
+            long_description=hints.long_description,
         )
 
 
diff --git a/wanga/schema/extractor_fns.py b/wanga/schema/extractor_fns.py
index e5467a8..814acce 100644
--- a/wanga/schema/extractor_fns.py
+++ b/wanga/schema/extractor_fns.py
@@ -29,9 +29,9 @@ def _get_datetime_arg(name: str, required: bool = True) -> ObjectField:
 
 def extract_datetime(annotation: TypeAnnotation) -> CallableSchema | None:
     date_fields = ["year", "month", "day"]
-    time_fields = ["hour", "minute", "second"]  # We deliberaly omit microseconds
+    time_fields = ["hour", "minute", "second"]  # We deliberately omit microseconds
 
-    delta_fields = ["days", "seconds"]  # We deliberaly omit microseconds
+    delta_fields = ["days", "seconds"]  # We deliberately omit microseconds
 
     fields = []
     if annotation in [date, datetime]:
@@ -52,4 +52,5 @@
             hint=None,
         ),
         return_schema=UndefinedNode(original_annotation=None),
+        long_description=None,
     )
diff --git a/wanga/schema/jsonschema.py b/wanga/schema/jsonschema.py
new file mode 100644
index 0000000..e60cfd23
--- /dev/null
+++ b/wanga/schema/jsonschema.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Literal, TypeAlias, TypedDict
+
+__all__ = [
+    "JsonSchemaFlavour",
+    "LeafJsonSchema",
+    "EnumJsonSchema",
+    "ObjectJsonSchema",
+    "AnthropicCallableSchema",
+    "OpenAICallableSchema",
+    "CallableJsonSchema",
+    "JsonSchema",
+    "ArrayJsonSchema",
+    "LeafTypeName",
+]
+
+
+LeafTypeName: TypeAlias = Literal["string", "number", "integer", "boolean", "null"]
+
+
+class LeafJsonSchema(TypedDict, total=False):
+    type: list[LeafTypeName] | LeafTypeName
+    description: str
+
+
+class EnumJsonSchema(TypedDict, total=False):
+    type: Literal["string"]
+    enum: list[str]
+    description: str
+
+
+class ObjectJsonSchema(TypedDict, total=False):
+    type: Literal["object"]
+    properties: dict[str, JsonSchema]
+    required: list[str]
+    description: str
+
+
+class ArrayJsonSchema(TypedDict, total=False):
+    type: Literal["array"]
+    items: JsonSchema
+    description: str
+
+
+JsonSchema: TypeAlias = (
+    LeafJsonSchema | EnumJsonSchema | ObjectJsonSchema | ArrayJsonSchema
+)
+
+
+class AnthropicCallableSchema(TypedDict, total=False):
+    name: str
+    description: str
+    input_schema: ObjectJsonSchema
+
+
+class OpenAICallableSchema(TypedDict, total=False):
+    name: str
+    description: str
+    parameters: ObjectJsonSchema
+
+
+class JsonSchemaFlavour(Enum):
+    r"""Top-level layout of the JSON schema as accepted by different LLMs."""
+
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+
+
+CallableJsonSchema: TypeAlias = AnthropicCallableSchema | OpenAICallableSchema
diff --git a/wanga/schema/normalize.py b/wanga/schema/normalize.py
index e594fcb..eb537b4 100644
--- a/wanga/schema/normalize.py
+++ b/wanga/schema/normalize.py
@@ -2,7 +2,7 @@
 import collections.abc
 import typing  # noqa
 from types import NoneType, UnionType
-from typing import Annotated, Union, get_args, get_origin
+from typing import Annotated, Literal, Union, get_args, get_origin
 
 from .utils import TypeAnnotation
 
@@ -42,58 +42,33 @@ def unpack_optional(annotation: TypeAnnotation) -> type[UnionType] | None:
     return _fold_or(result)
 
 
-# Those aliases are automatically resolved by the `typing.get_origin`, so there is no
-# direct need to handle them explicitly, by we still include them here for the reference
-# purposes.
-#
-# GENERIC_ALIASES = {
-#     # Basic aliases
-#     typing.Dict: dict,
-#     typing.List: list,
-#     typing.Set: set,
-#     typing.FrozenSet: frozenset,
-#     typing.Tuple: tuple,
-#     typing.Type: type,
-#     # collections aliases
-#     typing.DefaultDict: collections.defaultdict,
-#     typing.OrderedDict: collections.OrderedDict,
-#     typing.ChainMap: collections.ChainMap,
-#     typing.Counter: collections.Counter,
-#     typing.Deque: collections.deque,
-#     # re aliases
-#     typing.Pattern: re.Pattern,
-#     typing.Match: re.Match,
-#     typing.Text: str,
-#     # collections.abc aliases
-#     typing.AbstractSet: collections.abc.Set,
-#     typing.ByteString: collections.abc.ByteString,
-#     typing.Collection: collections.abc.Collection,
-#     typing.Container: collections.abc.Container,
-#     typing.ItemsView: collections.abc.ItemsView,
-#     typing.KeysView: collections.abc.KeysView,
-#     typing.Mapping: collections.abc.Mapping,
-#     typing.MappingView: collections.abc.MappingView,
-#     typing.MutableMapping: collections.abc.MutableMapping,
-#     typing.MutableSequence: collections.abc.MutableSequence,
-#     typing.MutableSet: collections.abc.MutableSet,
-#     typing.Sequence: collections.abc.Sequence,
-#     typing.ValuesView: collections.abc.ValuesView,
-#     typing.Coroutine: collections.abc.Coroutine,
-#     typing.AsyncGenerator: collections.abc.AsyncGenerator,
-#     typing.AsyncIterable: collections.abc.AsyncIterable,
-#     typing.AsyncIterator: collections.abc.AsyncIterator,
-#     typing.Awaitable: collections.abc.Awaitable,
-#     typing.Generator: collections.abc.Generator,
-#     typing.Iterable: collections.abc.Iterable,
-#     typing.Iterator: collections.abc.Iterator,
-#     typing.Callable: collections.abc.Callable,
-#     typing.Hashable: collections.abc.Hashable,
-#     typing.Reversible: collections.abc.Reversible,
-#     typing.Sized: collections.abc.Sized,
-#     # contextlib aliases
-#     typing.ContextManager: contextlib.AbstractContextManager,
-#     typing.AsyncContextManager: 
contextlib.AbstractAsyncContextManager, -# } +def normalize_literals(annotation: TypeAnnotation) -> TypeAnnotation: + r"""Merges literals within unions. + + Examples: + >>> normalize_literals(typing.Literal[1] | typing.Literal[2]) + typing.Literal[1, 2] + >>> normalize_literals(typing.Literal[1] | typing.Literal[2] | str) + typing.Union[str, typing.Literal[1, 2]] + """ + origin = get_origin(annotation) + args = get_args(annotation) + if origin in [Literal, None]: + return annotation + args = tuple(normalize_literals(arg) for arg in args) + if origin in [Union, UnionType]: + literals = [] + non_literals = [] + for arg in args: + if get_origin(arg) is Literal: + literals.extend(get_args(arg)) + else: + non_literals.append(arg) + new_args = list(non_literals) + if literals: + new_args.append(Literal[tuple(literals)]) # type: ignore + return _fold_or(new_args) + return origin[args] ABSTRACT_TO_CONCRETE = { @@ -116,6 +91,28 @@ def unpack_optional(annotation: TypeAnnotation) -> type[UnionType] | None: } +def _normalize_annotation_rec( + annotation: TypeAnnotation, concretize: bool = False +) -> TypeAnnotation: + origin = get_origin(annotation) + args = get_args(annotation) + if origin is None: + return annotation + if args: + args = tuple( + _normalize_annotation_rec(arg, concretize=concretize) for arg in args + ) + if origin is Annotated: + return args[0] + if origin in [Union, UnionType]: + return _fold_or(args) + if concretize: + origin = ABSTRACT_TO_CONCRETE.get(origin, origin) + if args: + return origin[args] + return origin + + def normalize_annotation( annotation: TypeAnnotation, concretize: bool = False ) -> TypeAnnotation: @@ -137,19 +134,9 @@ def normalize_annotation( collections.abc.Sequence[int] >>> normalize_annotation(collections.abc.Sequence[int], concretize=True) list[int] + >>> normalize_annotation(typing.Literal[1] | typing.Literal[2]) + typing.Literal[1, 2] """ - origin = get_origin(annotation) - args = get_args(annotation) - if origin is None: - return annotation - if args: - args = tuple(normalize_annotation(arg, concretize=concretize) for arg in args) - if origin is Annotated: - return args[0] - if origin is Union: - return _fold_or(args) - if concretize: - origin = ABSTRACT_TO_CONCRETE.get(origin, origin) - if args: - return origin[args] - return origin + result = normalize_literals(annotation) + result = _normalize_annotation_rec(result, concretize=concretize) + return result diff --git a/wanga/schema/schema.py b/wanga/schema/schema.py index 34514ff..5f7a5d4 100644 --- a/wanga/schema/schema.py +++ b/wanga/schema/schema.py @@ -1,9 +1,21 @@ from collections.abc import Mapping, Sequence from types import NoneType -from typing import Callable, TypeAlias +from typing import Callable, Literal, TypeAlias from attrs import evolve, frozen +from .jsonschema import ( + AnthropicCallableSchema, + ArrayJsonSchema, + CallableJsonSchema, + EnumJsonSchema, + JsonSchema, + JsonSchemaFlavour, + LeafJsonSchema, + LeafTypeName, + ObjectJsonSchema, + OpenAICallableSchema, +) from .utils import TypeAnnotation __all__ = [ @@ -24,12 +36,29 @@ JSON: TypeAlias = int | float | str | None | dict[str, "JSON"] | list["JSON"] +type_to_jsonname: dict[type | None, LeafTypeName] = { + int: "integer", + float: "number", + str: "string", + bool: "boolean", + None: "null", +} + + +class JsonSchemaGenerationError(ValueError): + pass + + @frozen class SchemaNode: r"""Base class for schema nodes.""" - def json_schema(self) -> JSON: - r"""Returns the JSON schema of the node to use in the LLM function call 
APIs."""
+    def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
+        r"""Returns the JSON schema of the node to use in the LLM function call APIs.
+
+        Args:
+            parent_hint: Hint from the parent object.
+        """
         raise NotImplementedError
 
 
@@ -44,6 +73,11 @@ class UndefinedNode(SchemaNode):
 
     original_annotation: NoneType | TypeAnnotation
 
+    def json_schema(self, parent_hint: str | None = None) -> LeafJsonSchema:
+        raise JsonSchemaGenerationError(
+            "JSON schema cannot be generated for missing or undefined annotations."
+        )
+
 
 @frozen
 class PrimitiveNode(SchemaNode):
@@ -57,6 +91,14 @@ class PrimitiveNode(SchemaNode):
 
     primitive_type: type[int] | type[float] | type[str] | type[bool]
 
+    def json_schema(self, parent_hint: str | None = None) -> LeafJsonSchema:
+        result = LeafJsonSchema(
+            type=type_to_jsonname[self.primitive_type],
+        )
+        if parent_hint:
+            result["description"] = parent_hint
+        return result
+
 
 @frozen
 class SequenceNode(SchemaNode):
@@ -71,6 +113,12 @@ class SequenceNode(SchemaNode):
     sequence_type: type[Sequence]
     item_schema: SchemaNode
 
+    def json_schema(self, parent_hint: str | None = None) -> ArrayJsonSchema:
+        result = ArrayJsonSchema(type="array", items=self.item_schema.json_schema())
+        if parent_hint:
+            result["description"] = parent_hint
+        return result
+
 
 @frozen
 class TupleNode(SchemaNode):
@@ -84,6 +132,11 @@ class TupleNode(SchemaNode):
     tuple_type: type[tuple]
     item_schemas: list[SchemaNode]
 
+    def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
+        raise JsonSchemaGenerationError(
+            "JSON schema cannot be generated for heterogeneous tuple types."
+        )
+
 
 @frozen
 class MappingNode(SchemaNode):
@@ -99,6 +152,11 @@ class MappingNode(SchemaNode):
     key_schema: SchemaNode
     value_schema: SchemaNode
 
+    def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
+        raise JsonSchemaGenerationError(
+            "JSON schema cannot be generated for Mapping types."
+        )
+
 
 @frozen
 class UnionNode(SchemaNode):
@@ -111,6 +169,52 @@ class UnionNode(SchemaNode):
 
     options: list[SchemaNode | None]
 
+    def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
+        if all(
+            option is None or isinstance(option, PrimitiveNode)
+            for option in self.options
+        ):
+            type_names = {
+                type_to_jsonname[option.primitive_type]  # type: ignore
+                for option in self.options
+                if option is not None
+            }
+            if "number" in type_names and "integer" in type_names:
+                type_names.remove("integer")
+            type_names = list(type_names)
+            if len(type_names) == 1:
+                type_names = type_names[0]
+            result = LeafJsonSchema(
+                type=type_names,  # type: ignore
+            )
+            if parent_hint:
+                result["description"] = parent_hint
+            return result
+        raise JsonSchemaGenerationError(
+            "JSON schema cannot be generated for non-trivial Union types."
+        )
+
+
+@frozen
+class LiteralNode(SchemaNode):
+    r"""Node corresponding to the `Literal` type.
+
+    Attributes:
+        options: The values of the literal.
+    """
+
+    options: list[int | float | str | bool]
+
+    def json_schema(self, parent_hint: str | None = None) -> EnumJsonSchema:
+        if not all(isinstance(option, str) for option in self.options):
+            raise JsonSchemaGenerationError(
+                "JSON schema can only be generated for string literal types."
+ ) + result = EnumJsonSchema(type="string", enum=self.options) # type: ignore + if parent_hint: + result["description"] = parent_hint + return result + @frozen class ObjectField: @@ -128,6 +232,9 @@ class ObjectField: hint: str | None required: bool + def json_schema(self) -> JsonSchema: + return self.schema.json_schema(parent_hint=self.hint) + @frozen class ObjectNode(SchemaNode): @@ -146,6 +253,22 @@ class ObjectNode(SchemaNode): fields: list[ObjectField] hint: str | None + def json_schema(self, parent_hint: str | None = None) -> ObjectJsonSchema: + result = ObjectJsonSchema( + type="object", + properties={field.name: field.json_schema() for field in self.fields}, + required=[field.name for field in self.fields if field.required], + ) + joint_hint = [] + if parent_hint: + joint_hint.append(parent_hint) + if self.hint: + joint_hint.append(self.hint) + joint_hint = "\n\n".join(joint_hint) + if joint_hint: + result["description"] = joint_hint + return result + @frozen class CallableSchema: @@ -155,16 +278,38 @@ class CallableSchema: call_schema: Schema of the function call. return_schema: Schema of the return value. None if the function returns None. + long_description: Long description extracted from the docstring. + It is used to pass tool descriptions to LLMs. It is not used + for return values. """ call_schema: ObjectNode return_schema: SchemaNode + long_description: str | None - def json_schema(self) -> JSON: - result = dict( - name=self.call_schema.name, - parameters=evolve(self.call_schema, hint=None).json_schema(), - ) + def json_schema( + self, flavour: JsonSchemaFlavour, include_long_description: bool = False + ) -> CallableJsonSchema: + full_description = [] if self.call_schema.hint: - result["description"] = self.call_schema.hint + full_description.append(self.call_schema.hint) + if self.long_description and include_long_description: + full_description.append(self.long_description) + full_description = "\n\n".join(full_description) + if flavour is JsonSchemaFlavour.ANTHROPIC: + result = AnthropicCallableSchema( + name=self.call_schema.name, + input_schema=evolve(self.call_schema, hint=None).json_schema(), + ) + if full_description: + result["description"] = full_description + elif flavour is JsonSchemaFlavour.OPENAI: + result = OpenAICallableSchema( + name=self.call_schema.name, + parameters=evolve(self.call_schema, hint=None).json_schema(), + ) + if full_description: + result["description"] = full_description + else: + raise ValueError(f"Unknown JSON schema flavour: {flavour}") return result diff --git a/wanga/schema/utils.py b/wanga/schema/utils.py index 35360a2..abd6503 100644 --- a/wanga/schema/utils.py +++ b/wanga/schema/utils.py @@ -4,6 +4,6 @@ "TypeAnnotation", ] -# Python doesn't have a way to speciy a type annotation for type annotations, +# Python doesn't have a way to specify a type annotation for type annotations, # so we use `Any` as a placeholder. TypeAnnotation: TypeAlias = Any
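A few illustrative snippets follow; they are not part of the patch and only sketch how the new pieces behave. First, the `Union`-vs-`UnionType` comment added to `annotation_to_schema` can be checked with nothing but the standard library (this is the behaviour of Python 3.10–3.13, which is what the comment describes):

import typing
from types import UnionType
from typing import Literal, Union, get_origin

# `|` between plain classes produces a `types.UnionType` instance.
assert get_origin(int | str) is UnionType

# As soon as a special form such as `Literal` participates, the result is a
# `typing.Union` instead, which is why the extractor checks for both origins.
assert get_origin(Literal[1] | str) is Union
assert get_origin(typing.Optional[int]) is Union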
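The new `tuple` branch in `annotation_to_schema` maps homogeneous `tuple[X, ...]` annotations to `SequenceNode`, while fixed-shape tuples still become `TupleNode`. A rough sketch of the resulting behaviour, assuming the import paths shown in this diff; the two functions are invented for illustration:

from wanga.schema.extractor import default_schema_extractor
from wanga.schema.schema import SequenceNode, TupleNode

def homogeneous(xs: tuple[int, ...] = ()):  # hypothetical example function
    """Accepts any number of ints."""

def heterogeneous(pair: tuple[int, str]):  # hypothetical example function
    """Accepts a fixed (int, str) pair."""

homo = default_schema_extractor.extract_schema(homogeneous)
assert isinstance(homo.call_schema.fields[0].schema, SequenceNode)

hetero = default_schema_extractor.extract_schema(heterogeneous)
assert isinstance(hetero.call_schema.fields[0].schema, TupleNode)
# TupleNode.json_schema raises JsonSchemaGenerationError, so heterogeneous
# tuples extract fine but cannot be rendered as a JSON schema.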
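The two `JsonSchemaFlavour` values only change the top-level layout produced by `CallableSchema.json_schema`; the object schema itself is identical. A hedged usage sketch (the `add` function is invented; the calls mirror the ones in `test_json_schema`):

from wanga.schema.extractor import default_schema_extractor
from wanga.schema.jsonschema import JsonSchemaFlavour

def add(a: int, b: int = 0):  # hypothetical example function
    """Add two numbers.

    The long description is only emitted when include_long_description=True.

    Args:
        a: First addend.
        b: Second addend.
    """
    return a + b

core = default_schema_extractor.extract_schema(add)

# OpenAI layout: the object schema sits under "parameters".
openai_tool = core.json_schema(JsonSchemaFlavour.OPENAI)

# Anthropic layout: the same object schema sits under "input_schema"; here the
# long description is appended to the short one, separated by a blank line.
anthropic_tool = core.json_schema(
    JsonSchemaFlavour.ANTHROPIC, include_long_description=True
)

assert openai_tool["parameters"] == anthropic_tool["input_schema"]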
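Finally, a sketch of how `UnionNode` and `LiteralNode` render to JSON schema, as implied by the code above: `int | float` collapses to a single "number" type, string literals become an enum, and optionality shows up in the "required" list rather than in the type. The `fn` function is invented for illustration:

import typing
from wanga.schema.extractor import default_schema_extractor
from wanga.schema.jsonschema import JsonSchemaFlavour

def fn(  # hypothetical example function
    mode: typing.Literal["fast", "slow"],
    amount: int | float,
    label: str | None = None,
):
    """Example tool."""

tool = default_schema_extractor.extract_schema(fn).json_schema(JsonSchemaFlavour.OPENAI)
props = tool["parameters"]["properties"]

assert props["mode"] == {"type": "string", "enum": ["fast", "slow"]}
# "integer" is dropped whenever "number" is present in a primitive union.
assert props["amount"] == {"type": "number"}
# `str | None` keeps only the non-null member; optionality is tracked via "required".
assert props["label"] == {"type": "string"}
assert tool["parameters"]["required"] == ["mode", "amount"]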