Skip to content

Commit

Permalink
Implement JSON Schema generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Jun 27, 2024
1 parent 8efa764 commit 24cde85
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 95 deletions.
112 changes: 109 additions & 3 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
from pydantic import BaseModel

from wanga.schema.extractor import default_schema_extractor
from wanga.schema.jsonschema import JsonSchemaFlavour
from wanga.schema.normalize import normalize_annotation, unpack_optional
from wanga.schema.schema import (
CallableSchema,
LiteralNode,
ObjectField,
ObjectNode,
PrimitiveNode,
SequenceNode,
UndefinedNode,
UnionNode,
)


Expand All @@ -28,6 +31,14 @@ def test_normalize_schema():
typing.Union[int, float]: int | float,
typing.Optional[int]: int | None,
typing.List: list,
typing.Union[typing.Union[int, float], str]: int | float | str,
(typing.Literal[1] | typing.Literal[2] | typing.Literal[3]): typing.Literal[
1, 2, 3
],
(
typing.Literal[1, 2]
| typing.Union[typing.Literal[2, 3], typing.Literal[3, 4]]
): (typing.Literal[1, 2, 3, 4]),
}
for annotation, result in expected.items():
assert normalize_annotation(annotation) == result
Expand Down Expand Up @@ -55,7 +66,7 @@ def test_concretize_schema():


def test_extract_schema():
def foo(x: int, y: str = "hello"): # noqa
def foo(x: int, y: str = "hello", z: tuple[int, ...] = ()): # noqa
pass

foo_schema = CallableSchema(
Expand All @@ -77,19 +88,30 @@ def foo(x: int, y: str = "hello"): # noqa
required=False,
hint=None,
),
ObjectField(
name="z",
schema=SequenceNode(
sequence_type=tuple,
item_schema=PrimitiveNode(primitive_type=int),
),
required=False,
hint=None,
),
],
),
long_description=None,
)

assert default_schema_extractor.extract_schema(foo) == foo_schema

def bar(x: typing.List[int]) -> int: # noqa
def bar(x: typing.List[int], y: typing.Literal["hehe"] | float) -> int: # noqa
r"""Bar.
Blah blah blah.
Args:
x: The x.
y: Hard example.
"""
return 0

Expand All @@ -109,8 +131,20 @@ def bar(x: typing.List[int]) -> int: # noqa
required=True,
hint="The x.",
),
ObjectField(
name="y",
schema=UnionNode(
[
PrimitiveNode(primitive_type=float),
LiteralNode(options=["hehe"]),
]
),
required=True,
hint="Hard example.",
),
],
),
long_description="Blah blah blah.",
)

assert default_schema_extractor.extract_schema(bar) == bar_schema
Expand Down Expand Up @@ -147,6 +181,7 @@ def __init__(self, x: int, y):
),
],
),
long_description=None,
)

assert default_schema_extractor.extract_schema(Baz) == baz_schema
Expand Down Expand Up @@ -186,6 +221,7 @@ class Qux:
),
],
),
long_description="I have attributes instead of arguments!",
)

assert default_schema_extractor.extract_schema(Qux) == qux_schema
Expand Down Expand Up @@ -258,6 +294,7 @@ class Goo:
),
],
),
long_description="I am a dataclass, and I use the stupid ReST docstring syntax!",
)

assert default_schema_extractor.extract_schema(Goo) == goo_schema
Expand All @@ -266,7 +303,7 @@ class Hoo(BaseModel):
r"""I am Hoo.
I am a Pydantic model!
And I use Numpy Doc format!.
And I use Numpy Doc format!
Parameters
----------
Expand Down Expand Up @@ -309,6 +346,75 @@ class Hoo(BaseModel):
),
],
),
long_description="I am a Pydantic model!\nAnd I use Numpy Doc format!",
)

assert default_schema_extractor.extract_schema(Hoo) == hoo_schema


def test_json_schema():
@frozen
class Inner:
"""Inner.
Long description of Inner.
Attributes:
x: The x.
"""

x: int

def foo(
a: int,
b: str,
c: Inner,
d: tuple[int, ...] = (),
e: typing.Literal["x", "y"] = "x",
f: str | int = 1,
):
r"""Foo!
Long description of foo.
Args:
a: The a.
b: The b.
c: The c.
"""

expected_json_schema = {
"name": "foo",
"description": "Foo!\n\nLong description of foo.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The a."},
"b": {"type": "string", "description": "The b."},
"c": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The x."}},
"required": ["x"],
"description": "The c.\n\nInner.",
},
"d": {
"type": "array",
"items": {"type": "integer"},
},
"e": {
"type": "string",
"enum": ["x", "y"],
},
"f": {
"type": ["integer", "string"],
},
},
"required": ["a", "b", "c"],
},
}

core_schema = default_schema_extractor.extract_schema(foo)
json_schema = core_schema.json_schema(
JsonSchemaFlavour.OPENAI, include_long_description=True
)
assert json_schema == expected_json_schema
49 changes: 37 additions & 12 deletions wanga/schema/extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from collections.abc import Callable
from types import NoneType, UnionType
from typing import Any, get_args, get_origin
from typing import Any, Literal, Union, get_args, get_origin

from attrs import define, field, frozen
from docstring_parser import parse as parse_docstring
Expand All @@ -10,6 +10,7 @@
from .normalize import normalize_annotation
from .schema import (
CallableSchema,
LiteralNode,
MappingNode,
ObjectField,
ObjectNode,
Expand All @@ -30,6 +31,7 @@
@frozen
class DocstringHints:
object_hint: str | None
long_decsription: str | None
param_to_hint: dict[str, str]


Expand All @@ -45,7 +47,7 @@ class SchemaExtractor:
If you want to add new extraction functions, use the `register_extract_fn` method.
Attributes:
exctractor_functions: List of functions that take type annotation as an input
extractor_functions: List of functions that take type annotation as an input
and try to produce the `CallableSchema`.
"""

Expand All @@ -66,21 +68,28 @@ def extract_hints(self, callable: Callable) -> DocstringHints:
if docstring is not None:
docstring = parse_docstring(docstring)
object_hint = docstring.short_description
long_description = docstring.long_description
param_to_hint = {
param.arg_name: param.description
for param in docstring.params
if param.description
}
else:
object_hint = None
long_description = None
param_to_hint = {}

if isinstance(callable, type) and hasattr(callable, "__init__"):
init_hints = self.extract_hints(callable.__init__)
object_hint = object_hint or init_hints.object_hint
long_description = long_description or init_hints.long_decsription
param_to_hint.update(init_hints.param_to_hint)

return DocstringHints(object_hint, param_to_hint)
return DocstringHints(
object_hint,
long_description,
param_to_hint,
)

def annotation_to_schema(self, annotation) -> SchemaNode:
if annotation in [Any, None]:
Expand All @@ -99,6 +108,22 @@ def annotation_to_schema(self, annotation) -> SchemaNode:
init_schema = self.extract_schema(annotation)
return init_schema.call_schema

if origin is Literal:
return LiteralNode(options=list(args))

if origin in [Union, UnionType]:
# Reader may think that the second check is unnecessary, since Unions should have been
# converted to `|` by the normalization step. Unfortunately, Literal[1] | str
# will evaluate to Union, and not UnionType, so we have to check against both,
# Union and UnionType.
arg_schemas = []
for arg in args:
if arg is NoneType:
arg_schemas.append(None)
else:
arg_schemas.append(self.annotation_to_schema(arg))
return UnionNode(options=arg_schemas)

# Normalization step has already converted all abstract classes to the corresponding concrete types.
# So we can safely check against list and dict.
if issubclass(origin, list):
Expand All @@ -110,6 +135,11 @@ def annotation_to_schema(self, annotation) -> SchemaNode:
item_schema=self.annotation_to_schema(args[0]),
)
if issubclass(origin, tuple):
if len(args) == 2 and args[1] is Ellipsis:
return SequenceNode(
sequence_type=origin,
item_schema=self.annotation_to_schema(args[0]),
)
return TupleNode(
tuple_type=origin,
item_schemas=[self.annotation_to_schema(arg) for arg in args],
Expand All @@ -123,14 +153,6 @@ def annotation_to_schema(self, annotation) -> SchemaNode:
key_schema=self.annotation_to_schema(args[0]),
value_schema=self.annotation_to_schema(args[1]),
)
if issubclass(origin, UnionType):
arg_schemas = []
for arg in args:
if arg is NoneType:
arg_schemas.append(None)
else:
arg_schemas.append(self.annotation_to_schema(arg))
return UnionNode(options=arg_schemas)

raise ValueError(f"Unsupported type annotation: {annotation}")

Expand All @@ -139,7 +161,9 @@ def extract_schema(self, callable: Callable) -> CallableSchema:
try:
return self._extract_schema_impl(callable)
except Exception as e:
raise SchemaExtractionError(f"Failed to extract schema for {callable}: {e}")
raise SchemaExtractionError(
f"Failed to extract schema for {callable}"
) from e

def _extract_schema_impl(self, callable: Callable) -> CallableSchema:
for fn in self.exctractor_fns:
Expand Down Expand Up @@ -193,6 +217,7 @@ def _extract_schema_impl(self, callable: Callable) -> CallableSchema:
hint=hints.object_hint,
),
return_schema=return_schema,
long_description=hints.long_decsription,
)


Expand Down
5 changes: 3 additions & 2 deletions wanga/schema/extractor_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def _get_datetime_arg(name: str, required: bool = True) -> ObjectField:

def extract_datetime(annotation: TypeAnnotation) -> CallableSchema | None:
date_fields = ["year", "month", "day"]
time_fields = ["hour", "minute", "second"] # We deliberaly omit microseconds
time_fields = ["hour", "minute", "second"] # We deliberately omit microseconds

delta_fields = ["days", "seconds"] # We deliberaly omit microseconds
delta_fields = ["days", "seconds"] # We deliberately omit microseconds

fields = []
if annotation in [date, datetime]:
Expand All @@ -52,4 +52,5 @@ def extract_datetime(annotation: TypeAnnotation) -> CallableSchema | None:
hint=None,
),
return_schema=UndefinedNode(original_annotation=None),
long_description=None,
)
Loading

0 comments on commit 24cde85

Please sign in to comment.