Skip to content

Commit

Permalink
Change maximum line width to 120 symbols.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Jun 28, 2024
1 parent 681e8a7 commit f219ee3
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 93 deletions.
4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def read_poetry_config(filename: Path | str) -> PoetryConfig:
filename = Path(filename)
with open(filename, "r") as file:
config = toml.load(file)["tool"]["poetry"]
return PoetryConfig(
project=config["name"], version=config["version"], author=config["authors"][0]
)
return PoetryConfig(project=config["name"], version=config["version"], author=config["authors"][0])


poetry_config = read_poetry_config(project_root / "pyproject.toml")
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ sphinx-rtd-theme = "^2.0.0"
sphinxcontrib-napoleon = "^0.7"
toml = "^0.10.2"

[tool.pytest.ini_options]
addopts = "--doctest-modules"

[tool.pyright]
pythonVersion = "3.10"
typeCheckingMode = "standard"

[tool.ruff]
target-version = "py310"
line-length = 120

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
2 changes: 0 additions & 2 deletions pytest.ini

This file was deleted.

21 changes: 5 additions & 16 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,8 @@ def test_normalize_schema():
typing.Optional[int]: int | None,
typing.List: list,
typing.Union[typing.Union[int, float], str]: int | float | str,
(typing.Literal[1] | typing.Literal[2] | typing.Literal[3]): typing.Literal[
1, 2, 3
],
(
typing.Literal[1, 2]
| typing.Union[typing.Literal[2, 3], typing.Literal[3, 4]]
): (typing.Literal[1, 2, 3, 4]),
(typing.Literal[1] | typing.Literal[2] | typing.Literal[3]): typing.Literal[1, 2, 3],
(typing.Literal[1, 2] | typing.Union[typing.Literal[2, 3], typing.Literal[3, 4]]): (typing.Literal[1, 2, 3, 4]),
}
for annotation, result in expected.items():
assert normalize_annotation(annotation) == result
Expand Down Expand Up @@ -255,9 +250,7 @@ class Goo:
name="date",
schema=ObjectNode(
constructor_fn=datetime,
constructor_signature=strip_self(
inspect.signature(datetime.__init__)
),
constructor_signature=strip_self(inspect.signature(datetime.__init__)),
name="datetime",
hint=None,
fields=[
Expand Down Expand Up @@ -335,9 +328,7 @@ class Hoo(BaseModel):
name="delta",
schema=ObjectNode(
constructor_fn=timedelta,
constructor_signature=strip_self(
inspect.signature(timedelta.__init__)
),
constructor_signature=strip_self(inspect.signature(timedelta.__init__)),
name="timedelta",
hint=None,
fields=[
Expand Down Expand Up @@ -428,9 +419,7 @@ def foo(
}

core_schema = default_schema_extractor.extract_schema(foo)
json_schema = core_schema.json_schema(
JsonSchemaFlavor.OPENAI, include_long_description=True
)
json_schema = core_schema.json_schema(JsonSchemaFlavor.OPENAI, include_long_description=True)
assert json_schema == expected_json_schema


Expand Down
18 changes: 4 additions & 14 deletions wanga/schema/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,7 @@ def extract_hints(self, callable: Callable) -> DocstringHints:
docstring = parse_docstring(docstring)
object_hint = docstring.short_description
long_description = docstring.long_description
param_to_hint = {
param.arg_name: param.description
for param in docstring.params
if param.description
}
param_to_hint = {param.arg_name: param.description for param in docstring.params if param.description}
else:
object_hint = None
long_description = None
Expand Down Expand Up @@ -129,9 +125,7 @@ def annotation_to_schema(self, annotation) -> SchemaNode:
# Normalization step has already converted all abstract classes to the corresponding concrete types.
# So we can safely check against list and dict.
if issubclass(origin, list):
assert (
len(args) == 1
), "Sequence type annotation should have exactly one argument."
assert len(args) == 1, "Sequence type annotation should have exactly one argument."
return SequenceNode(
sequence_type=origin,
item_schema=self.annotation_to_schema(args[0]),
Expand All @@ -147,9 +141,7 @@ def annotation_to_schema(self, annotation) -> SchemaNode:
item_schemas=[self.annotation_to_schema(arg) for arg in args],
)
if issubclass(origin, dict):
assert (
len(args) == 2
), "Mapping type annotation should have exactly two arguments."
assert len(args) == 2, "Mapping type annotation should have exactly two arguments."
return MappingNode(
mapping_type=origin,
key_schema=self.annotation_to_schema(args[0]),
Expand All @@ -163,9 +155,7 @@ def extract_schema(self, callable: Callable) -> CallableSchema:
try:
return self._extract_schema_impl(callable)
except Exception as e:
raise SchemaExtractionError(
f"Failed to extract schema for {callable}"
) from e
raise SchemaExtractionError(f"Failed to extract schema for {callable}") from e

def _extract_schema_impl(self, callable: Callable) -> CallableSchema:
for fn in self.exctractor_fns:
Expand Down
4 changes: 1 addition & 3 deletions wanga/schema/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ class ArrayJsonSchema(TypedDict, total=False):
description: str


JsonSchema: TypeAlias = (
LeafJsonSchema | EnumJsonSchema | ObjectJsonSchema | ArrayJsonSchema
)
JsonSchema: TypeAlias = LeafJsonSchema | EnumJsonSchema | ObjectJsonSchema | ArrayJsonSchema


class AnthropicCallableSchema(TypedDict, total=False):
Expand Down
12 changes: 3 additions & 9 deletions wanga/schema/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,13 @@ def normalize_literals(annotation: TypeAnnotation) -> TypeAnnotation:
}


def _normalize_annotation_rec(
annotation: TypeAnnotation, concretize: bool = False
) -> TypeAnnotation:
def _normalize_annotation_rec(annotation: TypeAnnotation, concretize: bool = False) -> TypeAnnotation:
origin = get_origin(annotation)
args = get_args(annotation)
if origin is None:
return annotation
if args:
args = tuple(
_normalize_annotation_rec(arg, concretize=concretize) for arg in args
)
args = tuple(_normalize_annotation_rec(arg, concretize=concretize) for arg in args)
if origin is Annotated:
return args[0]
if origin in [Union, UnionType]:
Expand All @@ -113,9 +109,7 @@ def _normalize_annotation_rec(
return origin


def normalize_annotation(
annotation: TypeAnnotation, concretize: bool = False
) -> TypeAnnotation:
def normalize_annotation(annotation: TypeAnnotation, concretize: bool = False) -> TypeAnnotation:
r"""Normalize a type annotation to a standard form.
Strips `Annotated` tags and replaces generic aliases with corresponding generic types.
Expand Down
50 changes: 12 additions & 38 deletions wanga/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ class UndefinedNode(SchemaNode):
original_annotation: NoneType | TypeAnnotation

def json_schema(self, parent_hint: str | None = None) -> LeafJsonSchema:
raise UnsupportedSchemaError(
"JSON schema cannot be generated for missing or undefined annotations."
)
raise UnsupportedSchemaError("JSON schema cannot be generated for missing or undefined annotations.")

def eval(self, value: JSON) -> NoReturn:
raise UnsupportedSchemaError("Cannot evaluate undefined schema.")
Expand Down Expand Up @@ -118,9 +116,7 @@ def eval(self, value: JSON) -> int | float | str | bool:
if self.primitive_type is float and isinstance(value, int):
return float(value)
else:
raise SchemaValidationError(
f"Expected {self.primitive_type}, got {value}"
)
raise SchemaValidationError(f"Expected {self.primitive_type}, got {value}")
return value


Expand Down Expand Up @@ -162,21 +158,14 @@ class TupleNode(SchemaNode):
item_schemas: list[SchemaNode]

def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
raise UnsupportedSchemaError(
"JSON schema cannot be generated for heterogeneous tuple types."
)
raise UnsupportedSchemaError("JSON schema cannot be generated for heterogeneous tuple types.")

def eval(self, value: JSON) -> tuple:
if not isinstance(value, list):
raise SchemaValidationError(f"Expected list, got {value}")
if len(value) != len(self.item_schemas):
raise SchemaValidationError(
f"Expected tuple of length {len(self.item_schemas)}, got {len(value)}"
)
return tuple(
item_schema.eval(item)
for item_schema, item in zip(self.item_schemas, value)
)
raise SchemaValidationError(f"Expected tuple of length {len(self.item_schemas)}, got {len(value)}")
return tuple(item_schema.eval(item) for item_schema, item in zip(self.item_schemas, value))


@frozen
Expand All @@ -194,9 +183,7 @@ class MappingNode(SchemaNode):
value_schema: SchemaNode

def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
raise UnsupportedSchemaError(
"JSON schema cannot be generated for Mapping types."
)
raise UnsupportedSchemaError("JSON schema cannot be generated for Mapping types.")

def eval(self, value: JSON) -> NoReturn:
raise UnsupportedSchemaError("Cannot evaluate Mapping schema.")
Expand All @@ -215,16 +202,11 @@ class UnionNode(SchemaNode):

@property
def is_primitive(self) -> bool:
return all(
option is None or isinstance(option, PrimitiveNode)
for option in self.options
)
return all(option is None or isinstance(option, PrimitiveNode) for option in self.options)

def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
if not self.is_primitive:
raise UnsupportedSchemaError(
"JSON schema cannot be generated for non-trivial Union types."
)
raise UnsupportedSchemaError("JSON schema cannot be generated for non-trivial Union types.")
type_names = [
_type_to_jsonname[option.primitive_type] # type: ignore
for option in self.options
Expand Down Expand Up @@ -254,9 +236,7 @@ def eval(self, value: JSON) -> JSON:
return option.eval(value)
except SchemaValidationError:
continue
raise SchemaValidationError(
f"Value {value} does not match any of the options: {self.options}"
)
raise SchemaValidationError(f"Value {value} does not match any of the options: {self.options}")


@frozen
Expand All @@ -271,19 +251,15 @@ class LiteralNode(SchemaNode):

def json_schema(self, parent_hint: str | None = None) -> EnumJsonSchema:
if not all(isinstance(option, str) for option in self.options):
raise UnsupportedSchemaError(
"JSON schema can only be generated for string literal types."
)
raise UnsupportedSchemaError("JSON schema can only be generated for string literal types.")
result = EnumJsonSchema(type="string", enum=self.options) # type: ignore
if parent_hint:
result["description"] = parent_hint
return result

def eval(self, value: JSON) -> int | float | str | bool:
if value not in self.options:
raise SchemaValidationError(
f"Value {value} does not match any of the options: {self.options}"
)
raise SchemaValidationError(f"Value {value} does not match any of the options: {self.options}")
return value


Expand Down Expand Up @@ -387,9 +363,7 @@ class CallableSchema:
return_schema: SchemaNode
long_description: str | None

def json_schema(
self, flavor: JsonSchemaFlavor, include_long_description: bool = False
) -> CallableJsonSchema:
def json_schema(self, flavor: JsonSchemaFlavor, include_long_description: bool = False) -> CallableJsonSchema:
r"""Extract JSON Schema to use with the LLM function call APIs.
Args:
Expand Down
6 changes: 1 addition & 5 deletions wanga/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,4 @@


def strip_self(signature: Signature) -> Signature:
return signature.replace(
parameters=[
param for name, param in signature.parameters.items() if name != "self"
]
)
return signature.replace(parameters=[param for name, param in signature.parameters.items() if name != "self"])
4 changes: 1 addition & 3 deletions wanga/templates/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@


ENVIROMENT = jinja2.Environment(autoescape=False, undefined=jinja2.StrictUndefined)
ENVIROMENT.globals.update(
{function.__name__: function for function in BUILTIN_FUNCTIONS}
)
ENVIROMENT.globals.update({function.__name__: function for function in BUILTIN_FUNCTIONS})


def make_template(template_string: str) -> jinja2.Template:
Expand Down

0 comments on commit f219ee3

Please sign in to comment.