Skip to content

Commit

Permalink
min items and max items pydantic and json schema
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Oct 12, 2024
1 parent 7569d7c commit 3d11ba7
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 15 deletions.
76 changes: 72 additions & 4 deletions src/formatron/formats/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,23 +115,91 @@ def number_metadata(current: typing.Type, nonterminal: str):
return f"""{nonterminal} ::= #'{prefix}[1-9][0-9]*(\\.[0-9]+)?([eE][+-]?[0-9]+)?';\n""", []

raise ValueError(f"{current.type.__name__} metadata {current.metadata} is not supported in json_generators!")

def sequence_metadata(current: typing.Type, nonterminal: str):
    """
    Generate KBNF rules for a sequence type constrained by item count.

    Reads the ``min_length`` and ``max_length`` entries of ``current.metadata``
    and emits rules for ``nonterminal`` that accept a JSON array whose number
    of items lies in ``[min_length, max_length]``. Items are always separated
    by ``comma``.

    Args:
        current: An object exposing ``metadata`` (a mapping) and ``type``
            (the annotated sequence type, e.g. ``typing.List[int]``).
        nonterminal: The nonterminal the generated rules define.
    Returns:
        ``None`` when neither bound is present in the metadata. Otherwise a
        tuple ``(rules, pending)`` where ``rules`` is the KBNF text and
        ``pending`` is a list of ``(type, nonterminal)`` pairs that still need
        definitions.
    Raises:
        ValueError: If ``min_length`` is greater than ``max_length``; no array
            can satisfy such bounds.
    """
    min_items = current.metadata.get("min_length")
    max_items = current.metadata.get("max_length")
    if min_items is None and max_items is None:
        return None
    if min_items is None:
        min_items = 0
    if min_items == 0 and max_items is None:
        # No effective constraint: hand the bare type back under the SAME
        # nonterminal so that the nonterminal referenced by the parent rule
        # actually gets defined by the ordinary sequence generator.
        return "", [(current.type, nonterminal)]
    if max_items is not None and min_items > max_items:
        raise ValueError(
            f"min_length {min_items} is greater than max_length {max_items} for {nonterminal}!")

    new_nonterminal = f"{nonterminal}_item"

    def comma_separated(count: int) -> str:
        # "item comma item comma item" with exactly `count` items.
        return " comma ".join([new_nonterminal] * count)

    if max_items is None:
        # Here min_items >= 1: the mandatory items, then any number of extra
        # items where each extra item carries its own leading comma.
        ebnf_rules = [
            f"{nonterminal} ::= array_begin {comma_separated(min_items)} "
            f"(comma {new_nonterminal})* array_end;"
        ]
    else:
        # One alternative per allowed length, including exactly min_items.
        ebnf_rules = []
        for count in range(min_items, max_items + 1):
            if count == 0:
                ebnf_rules.append(f"{nonterminal} ::= array_begin array_end;")
            else:
                ebnf_rules.append(
                    f"{nonterminal} ::= array_begin {comma_separated(count)} array_end;")

    # Handle the item type; an unparameterized sequence accepts any value.
    args = typing.get_args(current.type)
    item_type = args[0] if args else typing.Any
    return "\n".join(ebnf_rules) + "\n", [(item_type, new_nonterminal)]

def is_sequence_like(current: typing.Type) -> bool:
    """
    Tell whether a type annotation denotes a sequence.

    The annotation is first reduced to its origin (``typing.List[int]``
    becomes ``list``); an annotation without an origin is inspected as-is.
    The result is True when that origin is one of the ``typing.Sequence``,
    ``typing.List`` or ``typing.Tuple`` aliases, or when it is a class derived
    from ``collections.abc.Sequence``, ``list`` or ``tuple``. Note that ``str``
    is a ``collections.abc.Sequence`` and therefore also qualifies.

    Args:
        current: The type to check.
    Returns:
        bool: True if the type is sequence-like, False otherwise.
    """
    origin = typing.get_origin(current)
    if origin is None:
        origin = current
    # Identity comparison against the typing aliases themselves.
    for alias in (typing.Sequence, typing.List, typing.Tuple):
        if origin is alias:
            return True
    # issubclass() only accepts real classes, so anything else is rejected.
    if not isinstance(origin, type):
        return False
    return issubclass(origin, (collections.abc.Sequence, list, tuple))

def metadata(current: typing.Type, nonterminal: str):
    """
    Dispatch a metadata-carrying type to the matching constraint generator.

    Args:
        current: The type being processed; only instances of
            ``schemas.schema.TypeWithMetadata`` are handled here.
        nonterminal: The nonterminal that should be defined for ``current``.
    Returns:
        ``None`` when ``current`` is not a ``TypeWithMetadata`` or when its
        wrapped type is neither a string, a number nor sequence-like.
        When the metadata is empty, an empty rule string together with the
        wrapped type queued under the same nonterminal. Otherwise whatever
        the selected string/number/sequence generator returns.
    """
    if not isinstance(current, schemas.schema.TypeWithMetadata):
        return None
    wrapped = current.type
    if not current.metadata:
        # Nothing to constrain: let the plain type be handled normally.
        return "", [(wrapped, nonterminal)]
    wrapped_is_class = isinstance(wrapped, type)
    if wrapped_is_class and issubclass(wrapped, str):
        return string_metadata(current, nonterminal)
    if wrapped_is_class and issubclass(wrapped, (int, float)):
        return number_metadata(current, nonterminal)
    # Parameterized generics are classified by their origin.
    origin = typing.get_origin(wrapped)
    if origin is None:
        origin = wrapped
    if is_sequence_like(origin):
        return sequence_metadata(current, nonterminal)
    return None

def builtin_list(current: typing.Type, nonterminal: str):
def builtin_sequence(current: typing.Type, nonterminal: str):
original = typing.get_origin(current)
if original is None:
original = current
if original is typing.Sequence or isinstance(original, type) \
and issubclass(original, collections.abc.Sequence):
if is_sequence_like(original):
new_nonterminal = f"{nonterminal}_value"
annotation = typing.get_args(current)
if not annotation:
Expand Down Expand Up @@ -245,7 +313,7 @@ def builtin_simple_types(current: typing.Type, nonterminal: str):
register_generate_nonterminal_def(builtin_tuple)
register_generate_nonterminal_def(builtin_literal)
register_generate_nonterminal_def(builtin_union)
register_generate_nonterminal_def(builtin_list)
register_generate_nonterminal_def(builtin_sequence)
register_generate_nonterminal_def(builtin_dict)

def _generate_kbnf_grammar(schema: schemas.schema.Schema|collections.abc.Sequence, start_nonterminal: str) -> str:
Expand Down
7 changes: 7 additions & 0 deletions src/formatron/schemas/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ def _handle_list_metadata(obtained_type: typing.Type, schema: dict[str, typing.A
"""
if "items" in schema:
item_type = _convert_json_schema_to_our_schema(schema["items"], json_schema_id_to_schema)
metadata = {}
if "minItems" in schema:
metadata["min_length"] = schema["minItems"]
if "maxItems" in schema:
metadata["max_length"] = schema["maxItems"]
if metadata:
return schemas.schema.TypeWithMetadata(list, metadata)
return list[item_type]
return obtained_type

Expand Down
108 changes: 97 additions & 11 deletions tests/snapshots/snap_test_grammar_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,49 @@
start_inStock ::= start_inStock_required?;
start_inStock_required ::= boolean;
start_tags ::= start_tags_required?;
start_tags_required ::= array_begin (start_tags_required_value (comma start_tags_required_value)*)? array_end;
start_tags_required_value ::= start_tags_required_value_0 | start_tags_required_value_1;
start_tags_required_value_1 ::= number;
start_tags_required_value_0 ::= string;
start_tags_required ::= array_begin comma start_tags_required_item+ array_end;
start_tags_required_item ::= json_value;
start_price ::= #'0|[1-9][0-9]*(\\.[0-9]+)?([eE][+-]?[0-9]+)?';
start_name ::= start_tags_required_value;
start_name ::= start_name_0 | start_name_1;
start_name_1 ::= number;
start_name_0 ::= string;
'''

snapshots['test_json_schema_array_min_max_items_constraints 1'] = '''integer ::= #"-?(0|[1-9][0-9]*)";
number ::= #"-?(0|[1-9][0-9]*)(\\\\.[0-9]+)?([eE][+-]?[0-9]+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
boolean ::= "true"|"false";
null ::= "null";
array ::= array_begin (json_value (comma json_value)*)? array_end;
object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end;
json_value ::= number|string|boolean|null|array|object;
comma ::= #"[ \t
\r]*,[ \t
\r]*";
colon ::= #"[ \t
\r]*:[ \t
\r]*";
object_begin ::= #"\\\\{[ \t
\r]*";
object_end ::= #"[ \t
\r]*\\\\}";
array_begin ::= #"\\\\[[ \t
\r]*";
array_end ::= #"[ \t
\r]*\\\\]";
start ::= object_begin \'"min_items_array"\' colon start_min_items_array comma \'"max_items_array"\' colon start_max_items_array comma \'"min_max_items_array"\' colon start_min_max_items_array object_end;
start_min_max_items_array_min ::= start_min_max_items_array_item;
start_min_max_items_array ::= array_begin start_min_max_items_array_min comma start_min_max_items_array_item array_end;
start_min_max_items_array ::= array_begin start_min_max_items_array_min comma start_min_max_items_array_item comma start_min_max_items_array_item array_end;
start_min_max_items_array ::= array_begin start_min_max_items_array_min comma start_min_max_items_array_item comma start_min_max_items_array_item comma start_min_max_items_array_item array_end;
start_min_max_items_array_item ::= json_value;
start_max_items_array ::= array_begin array_end;
start_max_items_array ::= array_begin start_max_items_array_item array_end;
start_max_items_array ::= array_begin start_max_items_array_item comma start_max_items_array_item array_end;
start_max_items_array ::= array_begin start_max_items_array_item comma start_max_items_array_item comma start_max_items_array_item array_end;
start_max_items_array_item ::= start_min_max_items_array_item;
start_min_items_array ::= array_begin start_min_items_array_item comma start_min_items_array_item+ array_end;
start_min_items_array_item ::= start_min_max_items_array_item;
'''

snapshots['test_json_schema_integer_constraints 1'] = '''integer ::= #"-?(0|[1-9][0-9]*)";
Expand Down Expand Up @@ -317,6 +354,59 @@
start_gt_int ::= #'[1-9][0-9]*';
'''

snapshots['test_pydantic_sequence_constraints 1'] = '''integer ::= #"-?(0|[1-9][0-9]*)";
number ::= #"-?(0|[1-9][0-9]*)(\\\\.[0-9]+)?([eE][+-]?[0-9]+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
boolean ::= "true"|"false";
null ::= "null";
array ::= array_begin (json_value (comma json_value)*)? array_end;
object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end;
json_value ::= number|string|boolean|null|array|object;
comma ::= #"[ \t
\r]*,[ \t
\r]*";
colon ::= #"[ \t
\r]*:[ \t
\r]*";
object_begin ::= #"\\\\{[ \t
\r]*";
object_end ::= #"[ \t
\r]*\\\\}";
array_begin ::= #"\\\\[[ \t
\r]*";
array_end ::= #"[ \t
\r]*\\\\]";
start ::= object_begin \'"min_2_list"\' colon start_min_2_list comma \'"max_5_list"\' colon start_max_5_list comma \'"min_1_max_3_list"\' colon start_min_1_max_3_list comma \'"min_2_tuple"\' colon start_min_2_tuple comma \'"max_5_tuple"\' colon start_max_5_tuple comma \'"min_1_max_3_tuple"\' colon start_min_1_max_3_tuple comma \'"empty_list"\' colon start_empty_list object_end;
start_empty_list_item ::= array_begin (start_empty_list_item_value (comma start_empty_list_item_value)*)? array_end;
start_empty_list_item_value ::= json_value;
start_min_1_max_3_tuple_min ::= start_min_1_max_3_tuple_item;
start_min_1_max_3_tuple ::= array_begin start_min_1_max_3_tuple_min comma start_min_1_max_3_tuple_item array_end;
start_min_1_max_3_tuple ::= array_begin start_min_1_max_3_tuple_min comma start_min_1_max_3_tuple_item comma start_min_1_max_3_tuple_item array_end;
start_min_1_max_3_tuple_item ::= number;
start_max_5_tuple ::= array_begin array_end;
start_max_5_tuple ::= array_begin start_max_5_tuple_item array_end;
start_max_5_tuple ::= array_begin start_max_5_tuple_item comma start_max_5_tuple_item array_end;
start_max_5_tuple ::= array_begin start_max_5_tuple_item comma start_max_5_tuple_item comma start_max_5_tuple_item array_end;
start_max_5_tuple ::= array_begin start_max_5_tuple_item comma start_max_5_tuple_item comma start_max_5_tuple_item comma start_max_5_tuple_item array_end;
start_max_5_tuple ::= array_begin start_max_5_tuple_item comma start_max_5_tuple_item comma start_max_5_tuple_item comma start_max_5_tuple_item comma start_max_5_tuple_item array_end;
start_max_5_tuple_item ::= string;
start_min_2_tuple ::= array_begin start_min_2_tuple_item comma start_min_2_tuple_item+ array_end;
start_min_2_tuple_item ::= integer;
start_min_1_max_3_list_min ::= start_min_1_max_3_list_item;
start_min_1_max_3_list ::= array_begin start_min_1_max_3_list_min comma start_min_1_max_3_list_item array_end;
start_min_1_max_3_list ::= array_begin start_min_1_max_3_list_min comma start_min_1_max_3_list_item comma start_min_1_max_3_list_item array_end;
start_min_1_max_3_list_item ::= number;
start_max_5_list ::= array_begin array_end;
start_max_5_list ::= array_begin start_max_5_list_item array_end;
start_max_5_list ::= array_begin start_max_5_list_item comma start_max_5_list_item array_end;
start_max_5_list ::= array_begin start_max_5_list_item comma start_max_5_list_item comma start_max_5_list_item array_end;
start_max_5_list ::= array_begin start_max_5_list_item comma start_max_5_list_item comma start_max_5_list_item comma start_max_5_list_item array_end;
start_max_5_list ::= array_begin start_max_5_list_item comma start_max_5_list_item comma start_max_5_list_item comma start_max_5_list_item comma start_max_5_list_item array_end;
start_max_5_list_item ::= string;
start_min_2_list ::= array_begin start_min_2_list_item comma start_min_2_list_item+ array_end;
start_min_2_list_item ::= integer;
'''

snapshots['test_pydantic_string_constraints 1'] = '''integer ::= #"-?(0|[1-9][0-9]*)";
number ::= #"-?(0|[1-9][0-9]*)(\\\\.[0-9]+)?([eE][+-]?[0-9]+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
Expand Down Expand Up @@ -692,12 +782,8 @@
\r]*";
array_end ::= #"[ \t
\r]*\\\\]";
start ::= array_begin (start_value (comma start_value)*)? array_end;
start_value ::= object_begin \'"id"\' colon start_value_id comma \'"name"\' colon start_value_name comma \'"active"\' colon start_value_active object_end;
start_value_active ::= start_value_active_required?;
start_value_active_required ::= boolean;
start_value_name ::= string;
start_value_id ::= integer;
start ::= array_begin comma start_item+ array_end;
start_item ::= json_value;
'''

snapshots['test_schema_with_union_array_object 1'] = '''integer ::= #"-?(0|[1-9][0-9]*)";
Expand Down
43 changes: 43 additions & 0 deletions tests/test_grammar_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ class FloatConstraints(formatron.schemas.pydantic.ClassSchema):
result = JsonExtractor("start", None, FloatConstraints, lambda x: x).kbnf_definition
snapshot.assert_match(result)

def test_pydantic_sequence_constraints(snapshot):
    """Snapshot the grammar generated for length-constrained lists and tuples."""
    class SequenceConstraints(formatron.schemas.pydantic.ClassSchema):
        # Lists: lower bound only, upper bound only, then both bounds.
        min_2_list: typing.Annotated[typing.List[int], Field(min_length=2)]
        max_5_list: typing.Annotated[typing.List[str], Field(max_length=5)]
        min_1_max_3_list: typing.Annotated[typing.List[float], Field(min_length=1, max_length=3)]
        # The same three shapes for variable-length tuples.
        min_2_tuple: typing.Annotated[typing.Tuple[int, ...], Field(min_length=2)]
        max_5_tuple: typing.Annotated[typing.Tuple[str, ...], Field(max_length=5)]
        min_1_max_3_tuple: typing.Annotated[typing.Tuple[float, ...], Field(min_length=1, max_length=3)]
        # A zero lower bound with no upper bound.
        empty_list: typing.Annotated[typing.List[typing.Any], Field(min_length=0)]

    extractor = JsonExtractor("start", None, SequenceConstraints, lambda x: x)
    snapshot.assert_match(extractor.kbnf_definition)


def test_json_schema_integer_constraints(snapshot):
schema = {
"$schema": "https://json-schema.org/draft/2020-12/schema",
Expand Down Expand Up @@ -128,6 +142,35 @@ def test_json_schema_number_constraints(snapshot):
result = JsonExtractor("start", None, schema, lambda x: x).kbnf_definition
snapshot.assert_match(result)

def test_json_schema_array_min_max_items_constraints(snapshot):
    """Snapshot the grammar generated for JSON schema arrays with minItems/maxItems."""
    # One property per constraint shape: lower bound, upper bound, both.
    min_items_array = {"type": "array", "items": {"type": "string"}, "minItems": 2}
    max_items_array = {"type": "array", "items": {"type": "number"}, "maxItems": 3}
    min_max_items_array = {
        "type": "array",
        "items": {"type": "boolean"},
        "minItems": 1,
        "maxItems": 4,
    }
    raw_schema = {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "$id": "https://example.com/array-constraints-schema.json",
        "type": "object",
        "properties": {
            "min_items_array": min_items_array,
            "max_items_array": max_items_array,
            "min_max_items_array": min_max_items_array,
        },
        "required": ["min_items_array", "max_items_array", "min_max_items_array"],
    }
    compiled_schema = json_schema.create_schema(raw_schema)
    extractor = JsonExtractor("start", None, compiled_schema, lambda x: x)
    snapshot.assert_match(extractor.kbnf_definition)


def test_json_schema(snapshot):
schema = {
Expand Down

0 comments on commit 3d11ba7

Please sign in to comment.