Skip to content

Commit

Permalink
Implement schema evaluation.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Jun 28, 2024
1 parent 826a6e9 commit d37c3d4
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 44 deletions.
44 changes: 42 additions & 2 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import collections.abc
import inspect
import typing
from dataclasses import dataclass
from datetime import datetime, timedelta
Expand All @@ -8,7 +9,7 @@
from pydantic import BaseModel

from wanga.schema.extractor import default_schema_extractor
from wanga.schema.jsonschema import JsonSchemaFlavour
from wanga.schema.jsonschema import JsonSchemaFlavor
from wanga.schema.normalize import normalize_annotation, unpack_optional
from wanga.schema.schema import (
CallableSchema,
Expand All @@ -20,6 +21,7 @@
UndefinedNode,
UnionNode,
)
from wanga.schema.utils import strip_self


def test_normalize_schema():
Expand Down Expand Up @@ -73,6 +75,7 @@ def foo(x: int, y: str = "hello", z: tuple[int, ...] = ()): # noqa
return_schema=UndefinedNode(original_annotation=typing.Any),
call_schema=ObjectNode(
constructor_fn=foo,
constructor_signature=inspect.signature(foo),
name="foo",
hint=None,
fields=[
Expand Down Expand Up @@ -119,6 +122,7 @@ def bar(x: typing.List[int], y: typing.Literal["hehe"] | float) -> int: # noqa
return_schema=PrimitiveNode(primitive_type=int),
call_schema=ObjectNode(
constructor_fn=bar,
constructor_signature=inspect.signature(bar),
name="bar",
hint="Bar.",
fields=[
Expand Down Expand Up @@ -164,6 +168,7 @@ def __init__(self, x: int, y):
return_schema=UndefinedNode(original_annotation=typing.Any),
call_schema=ObjectNode(
constructor_fn=Baz,
constructor_signature=inspect.signature(Baz),
name="Baz",
hint="I am Baz.",
fields=[
Expand Down Expand Up @@ -204,6 +209,7 @@ class Qux:
return_schema=UndefinedNode(original_annotation=None),
call_schema=ObjectNode(
constructor_fn=Qux,
constructor_signature=inspect.signature(Qux),
name="Qux",
hint="I am Qux.",
fields=[
Expand Down Expand Up @@ -241,13 +247,17 @@ class Goo:
return_schema=UndefinedNode(original_annotation=None),
call_schema=ObjectNode(
constructor_fn=Goo,
constructor_signature=inspect.signature(Goo),
name="Goo",
hint="I am Goo.",
fields=[
ObjectField(
name="date",
schema=ObjectNode(
constructor_fn=datetime,
constructor_signature=strip_self(
inspect.signature(datetime.__init__)
),
name="datetime",
hint=None,
fields=[
Expand Down Expand Up @@ -317,13 +327,17 @@ class Hoo(BaseModel):
return_schema=UndefinedNode(original_annotation=None),
call_schema=ObjectNode(
constructor_fn=Hoo,
constructor_signature=inspect.signature(Hoo),
name="Hoo",
hint="I am Hoo.",
fields=[
ObjectField(
name="delta",
schema=ObjectNode(
constructor_fn=timedelta,
constructor_signature=strip_self(
inspect.signature(timedelta.__init__)
),
name="timedelta",
hint=None,
fields=[
Expand Down Expand Up @@ -415,6 +429,32 @@ def foo(

core_schema = default_schema_extractor.extract_schema(foo)
json_schema = core_schema.json_schema(
JsonSchemaFlavour.OPENAI, include_long_description=True
JsonSchemaFlavor.OPENAI, include_long_description=True
)
assert json_schema == expected_json_schema


def test_eval():
@frozen
class Hehe:
hehehe: int

def foo(
x: float,
/,
y: int = 3,
*,
z: typing.Literal["a", "b"],
hehe: Hehe | None = None,
):
return x

json_input = {
"x": 1,
"y": 2,
"z": "a",
"hehe": {"hehehe": 3},
}

core_schema = default_schema_extractor.extract_schema(foo)
assert core_schema.eval(json_input) == 1.0
11 changes: 4 additions & 7 deletions wanga/schema/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from attrs import define, field, frozen
from docstring_parser import parse as parse_docstring

from wanga.schema.utils import strip_self

from .extractor_fns import ExtractorFn, extract_datetime
from .normalize import normalize_annotation
from .schema import (
Expand Down Expand Up @@ -178,13 +180,7 @@ def _extract_schema_impl(self, callable: Callable) -> CallableSchema:
# correctly by `inspect.signature`. In such cases, we fall back to
# `__init__` signature, but we still use the original docstring for hints.
signature = inspect.signature(callable.__init__, eval_str=True)
signature = signature.replace(
parameters=[
param
for name, param in signature.parameters.items()
if name != "self"
]
)
signature = strip_self(signature)

return_type = signature.return_annotation

Expand Down Expand Up @@ -212,6 +208,7 @@ def _extract_schema_impl(self, callable: Callable) -> CallableSchema:
return CallableSchema(
call_schema=ObjectNode(
constructor_fn=callable,
constructor_signature=signature,
name=callable.__name__,
fields=object_fields,
hint=hints.object_hint,
Expand Down
4 changes: 3 additions & 1 deletion wanga/schema/extractor_fns.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from datetime import date, datetime, time, timedelta
from typing import Callable, TypeAlias

Expand All @@ -8,7 +9,7 @@
PrimitiveNode,
UndefinedNode,
)
from .utils import TypeAnnotation
from .utils import TypeAnnotation, strip_self

ExtractorFn: TypeAlias = Callable[[TypeAnnotation], CallableSchema | None]

Expand Down Expand Up @@ -47,6 +48,7 @@ def extract_datetime(annotation: TypeAnnotation) -> CallableSchema | None:
return CallableSchema(
call_schema=ObjectNode(
constructor_fn=annotation,
constructor_signature=strip_self(inspect.signature(annotation.__init__)),
name=annotation.__name__,
fields=fields,
hint=None,
Expand Down
4 changes: 2 additions & 2 deletions wanga/schema/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Literal, TypeAlias, TypedDict

__all__ = [
"JsonSchemaFlavour",
"JsonSchemaFlavor",
"LeafJsonSchema",
"EnumJsonSchema",
"ObjectJsonSchema",
Expand Down Expand Up @@ -61,7 +61,7 @@ class OpenAICallableSchema(TypedDict, total=False):
parameters: ObjectJsonSchema


class JsonSchemaFlavour(Enum):
class JsonSchemaFlavor(Enum):
r"""Top-level layout of the JSON schema as accepted by different LLMS."""

OPENAI = "openai"
Expand Down
Loading

0 comments on commit d37c3d4

Please sign in to comment.