Skip to content

Commit

Permalink
patterns, minlen,maxlen for pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Oct 10, 2024
1 parent bd5824d commit 75f7636
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 8 deletions.
42 changes: 39 additions & 3 deletions src/formatron/formats/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,39 @@ def schema(current: typing.Type, nonterminal: str):

def field_info(current: typing.Type, nonterminal: str):
if isinstance(current, schemas.schema.FieldInfo):
annotation = current.annotation
if current.required:
return "", [(current.annotation, nonterminal)]
return "", [(annotation, nonterminal)]
new_nonterminal = f"{nonterminal}_required"
return f"{nonterminal} ::= {new_nonterminal}?;\n", [(current.annotation, new_nonterminal)]
return f"{nonterminal} ::= {new_nonterminal}?;\n", [(annotation, new_nonterminal)]
return None

def string_metadata(current: typing.Type, nonterminal: str):
min_length = current.metadata.get("min_length")
max_length = current.metadata.get("max_length")
pattern = current.metadata.get("pattern")
if pattern:
assert not (min_length or max_length), "pattern is mutually exclusive with min_length and max_length"
repetition = None
if min_length is not None and max_length is None:
repetition = f"{{{min_length},}}"
elif min_length is None and max_length is not None:
repetition = f"{{0,{max_length}}}"
elif min_length is not None and max_length is not None:
repetition = f"{{{min_length},{max_length}}}"
if repetition is not None:
return fr"""{nonterminal} ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}}){repetition}"';
""", []
if pattern is not None:
pattern = pattern.replace("'", "\\'")
return f"""{nonterminal} ::= #'"{pattern}"';\n""", []

def metadata(current: typing.Type, nonterminal: str):
if isinstance(current, schemas.schema.TypeWithMetadata):
if not current.metadata:
return "", [(current.type, nonterminal)]
if isinstance(current.type, type) and issubclass(current.type, str):
return string_metadata(current, nonterminal)
return None

def builtin_list(current: typing.Type, nonterminal: str):
Expand Down Expand Up @@ -189,6 +218,7 @@ def builtin_simple_types(current: typing.Type, nonterminal: str):
register_generate_nonterminal_def(builtin_simple_types)
register_generate_nonterminal_def(schema)
register_generate_nonterminal_def(field_info)
register_generate_nonterminal_def(metadata)
register_generate_nonterminal_def(builtin_tuple)
register_generate_nonterminal_def(builtin_literal)
register_generate_nonterminal_def(builtin_union)
Expand Down Expand Up @@ -255,7 +285,13 @@ def __init__(self, nonterminal: str, capture_name: typing.Optional[str], schema:
- bool
- int
- float
- string
- str
- with min_length, max_length and pattern constraints
- length is measured in UTF-8 character number
- *Warning*: too large difference between min_length and max_length can lead to enormous memory consumption!
- pattern is mutually exclusive with min_length and max_length
- pattern will be compiled to a regular expression so all caveats of regular expressions apply
- pattern currently is automatically anchored at both ends
- NoneType
- typing.Any
- Subclasses of collections.abc.Mapping[str,T] and typing.Mapping[str,T] where T is a supported type,
Expand Down
13 changes: 11 additions & 2 deletions src/formatron/schemas/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pydantic.fields
from pydantic import BaseModel, validate_call, ConfigDict, Field

from formatron.schemas.schema import FieldInfo, Schema
from formatron.schemas.schema import FieldInfo, Schema, TypeWithMetadata


class FieldInfo(FieldInfo):
Expand All @@ -22,10 +22,19 @@ def __init__(self, field: pydantic.fields.FieldInfo):
Initialize the field information.
"""
self._field = field
self._annotation = field.annotation
if field.metadata:
metadata = {}
for constraint in ["min_length", "max_length", "pattern"]:
value = next((getattr(m, constraint) for m in self._field.metadata if hasattr(m, constraint)), None)
if value is not None:
metadata[constraint] = value
if metadata:
self._annotation = TypeWithMetadata(self._annotation, metadata)

@property
def annotation(self) -> typing.Type[typing.Any] | None:
return self._field.annotation
return self._annotation

@property
def required(self) -> bool:
Expand Down
12 changes: 12 additions & 0 deletions src/formatron/schemas/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ def required(self) -> bool:
"""
pass

class TypeWithMetadata:
def __init__(self, type: typing.Type[typing.Any], metadata: dict[str, typing.Any]|None):
self._type = type
self._metadata = metadata

@property
def type(self) -> typing.Type[typing.Any]:
return self._type

@property
def metadata(self) -> dict[str, typing.Any]|None:
return self._metadata

class Schema(abc.ABC):
"""
Expand Down
35 changes: 32 additions & 3 deletions tests/snapshots/snap_test_grammar_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
start_concepts_value ::= string;
start_related_queries ::= array_begin (start_related_queries_value (comma start_related_queries_value)*)? array_end;
start_related_queries_value ::= start_related_queries_value_0 | start_related_queries_value_1;
start_related_queries_value_1 ::= string;
start_related_queries_value_0 ::= object_begin \'"foo"\' colon start_related_queries_value_0_foo object_end;
start_related_queries_value_0_foo ::= integer;
start_related_queries_value_1 ::= object_begin \'"foo"\' colon start_related_queries_value_1_foo object_end;
start_related_queries_value_1_foo ::= integer;
start_related_queries_value_0 ::= string;
start_queries ::= array_begin (start_queries_value (comma start_queries_value)*)? array_end;
start_queries_value ::= start_queries_value_0 | start_queries_value_1 | start_queries_value_2;
start_queries_value_2 ::= boolean;
Expand Down Expand Up @@ -193,6 +193,35 @@
start_value ::= integer;
'''

snapshots['test_pydantic_string_constraints 1'] = '''integer ::= #"-?(0|[1-9]\\\\d*)";
number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
boolean ::= "true"|"false";
null ::= "null";
array ::= array_begin (json_value (comma json_value)*)? array_end;
object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end;
json_value ::= number|string|boolean|null|array|object;
comma ::= #"[ \t
\r]*,[ \t
\r]*";
colon ::= #"[ \t
\r]*:[ \t
\r]*";
object_begin ::= #"\\\\{[ \t
\r]*";
object_end ::= #"[ \t
\r]*\\\\}";
array_begin ::= #"\\\\[[ \t
\r]*";
array_end ::= #"[ \t
\r]*\\\\]";
start ::= object_begin \'"min_length_str"\' colon start_min_length_str comma \'"max_length_str"\' colon start_max_length_str comma \'"pattern_str"\' colon start_pattern_str comma \'"combined_str"\' colon start_combined_str object_end;
start_combined_str ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4}){2,5}"\';
start_pattern_str ::= #\'"^[a-zA-Z0-9]+$"\';
start_max_length_str ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4}){0,10}"\';
start_min_length_str ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4}){3,}"\';
'''

snapshots['test_recursive_binary_tree_schema 1'] = '''integer ::= #"-?(0|[1-9]\\\\d*)";
number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
Expand Down
11 changes: 11 additions & 0 deletions tests/test_grammar_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ def test_pydantic_class(snapshot):
result = JsonExtractor("start", None,Test,lambda x:x).kbnf_definition
snapshot.assert_match(result)

def test_pydantic_string_constraints(snapshot):
class StringConstraints(formatron.schemas.pydantic.ClassSchema):
min_length_str: typing.Annotated[str, Field(min_length=3)]
max_length_str: typing.Annotated[str, Field(max_length=10)]
pattern_str: typing.Annotated[str, Field(pattern=r'^[a-zA-Z0-9]+$')]
combined_str: typing.Annotated[str, Field(min_length=2, max_length=5)]

result = JsonExtractor("start", None, StringConstraints, lambda x: x).kbnf_definition
snapshot.assert_match(result)


def test_json_schema(snapshot):
schema = {
"$schema": "https://json-schema.org/draft/2020-12/schema",
Expand Down

0 comments on commit 75f7636

Please sign in to comment.