From 75f7636ed8fb24f2fea35c13d0a77df2b58b1abc Mon Sep 17 00:00:00 2001 From: Huanghe Date: Thu, 10 Oct 2024 03:02:27 +0000 Subject: [PATCH] patterns, minlen,maxlen for pydantic --- src/formatron/formats/json.py | 42 ++++++++++++++++++++++-- src/formatron/schemas/pydantic.py | 13 ++++++-- src/formatron/schemas/schema.py | 12 +++++++ tests/snapshots/snap_test_grammar_gen.py | 35 ++++++++++++++++++-- tests/test_grammar_gen.py | 11 +++++++ 5 files changed, 105 insertions(+), 8 deletions(-) diff --git a/src/formatron/formats/json.py b/src/formatron/formats/json.py index 491e9134..263fdbd5 100644 --- a/src/formatron/formats/json.py +++ b/src/formatron/formats/json.py @@ -68,10 +68,39 @@ def schema(current: typing.Type, nonterminal: str): def field_info(current: typing.Type, nonterminal: str): if isinstance(current, schemas.schema.FieldInfo): + annotation = current.annotation if current.required: - return "", [(current.annotation, nonterminal)] + return "", [(annotation, nonterminal)] new_nonterminal = f"{nonterminal}_required" - return f"{nonterminal} ::= {new_nonterminal}?;\n", [(current.annotation, new_nonterminal)] + return f"{nonterminal} ::= {new_nonterminal}?;\n", [(annotation, new_nonterminal)] + return None + + def string_metadata(current: typing.Type, nonterminal: str): + min_length = current.metadata.get("min_length") + max_length = current.metadata.get("max_length") + pattern = current.metadata.get("pattern") + if pattern: + assert not (min_length or max_length), "pattern is mutually exclusive with min_length and max_length" + repetition = None + if min_length is not None and max_length is None: + repetition = f"{{{min_length},}}" + elif min_length is None and max_length is not None: + repetition = f"{{0,{max_length}}}" + elif min_length is not None and max_length is not None: + repetition = f"{{{min_length},{max_length}}}" + if repetition is not None: + return fr"""{nonterminal} ::= #'"([^\\\\"\u0000-\u001f]|\\\\["\\\\bfnrt/]|\\\\u[0-9A-Fa-f]{{4}}){repetition}"'; +""", [] + if pattern is not None: + pattern = pattern.replace("'", "\\'") + return f"""{nonterminal} ::= #'"{pattern}"';\n""", [] + + def metadata(current: typing.Type, nonterminal: str): + if isinstance(current, schemas.schema.TypeWithMetadata): + if not current.metadata: + return "", [(current.type, nonterminal)] + if isinstance(current.type, type) and issubclass(current.type, str): + return string_metadata(current, nonterminal) return None def builtin_list(current: typing.Type, nonterminal: str): @@ -189,6 +218,7 @@ def builtin_simple_types(current: typing.Type, nonterminal: str): register_generate_nonterminal_def(builtin_simple_types) register_generate_nonterminal_def(schema) register_generate_nonterminal_def(field_info) + register_generate_nonterminal_def(metadata) register_generate_nonterminal_def(builtin_tuple) register_generate_nonterminal_def(builtin_literal) register_generate_nonterminal_def(builtin_union) @@ -255,7 +285,13 @@ def __init__(self, nonterminal: str, capture_name: typing.Optional[str], schema: - bool - int - float - - string + - str + - with min_length, max_length and pattern constraints + - length is measured in UTF-8 character number + - *Warning*: too large difference between min_length and max_length can lead to enormous memory consumption! + - pattern is mutually exclusive with min_length and max_length + - pattern will be compiled to a regular expression so all caveats of regular expressions apply + - pattern currently is automatically anchored at both ends - NoneType - typing.Any - Subclasses of collections.abc.Mapping[str,T] and typing.Mapping[str,T] where T is a supported type, diff --git a/src/formatron/schemas/pydantic.py b/src/formatron/schemas/pydantic.py index 33fdb5b8..16a5f1a6 100644 --- a/src/formatron/schemas/pydantic.py +++ b/src/formatron/schemas/pydantic.py @@ -8,7 +8,7 @@ import pydantic.fields from pydantic import BaseModel, validate_call, ConfigDict, Field -from formatron.schemas.schema import FieldInfo, Schema +from formatron.schemas.schema import FieldInfo, Schema, TypeWithMetadata class FieldInfo(FieldInfo): @@ -22,10 +22,19 @@ def __init__(self, field: pydantic.fields.FieldInfo): Initialize the field information. """ self._field = field + self._annotation = field.annotation + if field.metadata: + metadata = {} + for constraint in ["min_length", "max_length", "pattern"]: + value = next((getattr(m, constraint) for m in self._field.metadata if hasattr(m, constraint)), None) + if value is not None: + metadata[constraint] = value + if metadata: + self._annotation = TypeWithMetadata(self._annotation, metadata) @property def annotation(self) -> typing.Type[typing.Any] | None: - return self._field.annotation + return self._annotation @property def required(self) -> bool: diff --git a/src/formatron/schemas/schema.py b/src/formatron/schemas/schema.py index 7a9fa988..f7296755 100644 --- a/src/formatron/schemas/schema.py +++ b/src/formatron/schemas/schema.py @@ -22,6 +22,18 @@ def required(self) -> bool: """ pass +class TypeWithMetadata: + def __init__(self, type: typing.Type[typing.Any], metadata: dict[str, typing.Any]|None): + self._type = type + self._metadata = metadata + + @property + def type(self) -> typing.Type[typing.Any]: + return self._type + + @property + def metadata(self) -> dict[str, typing.Any]|None: + return self._metadata class Schema(abc.ABC): """ diff --git a/tests/snapshots/snap_test_grammar_gen.py b/tests/snapshots/snap_test_grammar_gen.py index c837858c..b3d9295c 100644 --- a/tests/snapshots/snap_test_grammar_gen.py +++ b/tests/snapshots/snap_test_grammar_gen.py @@ -36,9 +36,9 @@ start_concepts_value ::= string; start_related_queries ::= array_begin (start_related_queries_value (comma start_related_queries_value)*)? array_end; start_related_queries_value ::= start_related_queries_value_0 | start_related_queries_value_1; -start_related_queries_value_1 ::= string; -start_related_queries_value_0 ::= object_begin \'"foo"\' colon start_related_queries_value_0_foo object_end; -start_related_queries_value_0_foo ::= integer; +start_related_queries_value_1 ::= object_begin \'"foo"\' colon start_related_queries_value_1_foo object_end; +start_related_queries_value_1_foo ::= integer; +start_related_queries_value_0 ::= string; start_queries ::= array_begin (start_queries_value (comma start_queries_value)*)? array_end; start_queries_value ::= start_queries_value_0 | start_queries_value_1 | start_queries_value_2; start_queries_value_2 ::= boolean; @@ -193,6 +193,35 @@ start_value ::= integer; ''' +snapshots['test_pydantic_string_constraints 1'] = '''integer ::= #"-?(0|[1-9]\\\\d*)"; +number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?"; +string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\'; +boolean ::= "true"|"false"; +null ::= "null"; +array ::= array_begin (json_value (comma json_value)*)? array_end; +object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end; +json_value ::= number|string|boolean|null|array|object; +comma ::= #"[ \t +\r]*,[ \t +\r]*"; +colon ::= #"[ \t +\r]*:[ \t +\r]*"; +object_begin ::= #"\\\\{[ \t +\r]*"; +object_end ::= #"[ \t +\r]*\\\\}"; +array_begin ::= #"\\\\[[ \t +\r]*"; +array_end ::= #"[ \t +\r]*\\\\]"; +start ::= object_begin \'"min_length_str"\' colon start_min_length_str comma \'"max_length_str"\' colon start_max_length_str comma \'"pattern_str"\' colon start_pattern_str comma \'"combined_str"\' colon start_combined_str object_end; +start_combined_str ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4}){2,5}"\'; +start_pattern_str ::= #\'"^[a-zA-Z0-9]+$"\'; +start_max_length_str ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4}){0,10}"\'; +start_min_length_str ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4}){3,}"\'; +''' + snapshots['test_recursive_binary_tree_schema 1'] = '''integer ::= #"-?(0|[1-9]\\\\d*)"; number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?"; string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\'; diff --git a/tests/test_grammar_gen.py b/tests/test_grammar_gen.py index 9ae3b274..094c3a52 100644 --- a/tests/test_grammar_gen.py +++ b/tests/test_grammar_gen.py @@ -32,6 +32,17 @@ def test_pydantic_class(snapshot): result = JsonExtractor("start", None,Test,lambda x:x).kbnf_definition snapshot.assert_match(result) +def test_pydantic_string_constraints(snapshot): + class StringConstraints(formatron.schemas.pydantic.ClassSchema): + min_length_str: typing.Annotated[str, Field(min_length=3)] + max_length_str: typing.Annotated[str, Field(max_length=10)] + pattern_str: typing.Annotated[str, Field(pattern=r'^[a-zA-Z0-9]+$')] + combined_str: typing.Annotated[str, Field(min_length=2, max_length=5)] + + result = JsonExtractor("start", None, StringConstraints, lambda x: x).kbnf_definition + snapshot.assert_match(result) + + def test_json_schema(snapshot): schema = { "$schema": "https://json-schema.org/draft/2020-12/schema",