diff --git a/src/formatron/schemas/json_schema.py b/src/formatron/schemas/json_schema.py index 9e826ec..9ab26eb 100644 --- a/src/formatron/schemas/json_schema.py +++ b/src/formatron/schemas/json_schema.py @@ -158,11 +158,11 @@ def _infer_type(schema: dict[str, typing.Any], json_schema_id_to_schema: dict[in return _handle_anyOf(schema, json_schema_id_to_schema) obtained_type = _obtain_type(schema, json_schema_id_to_schema) args = typing.get_args(obtained_type) - if obtained_type is None or obtained_type is object or object in args: + if obtained_type is None: + obtained_type = typing.Union[str, float, int, bool, None, list[typing.Any]] + if obtained_type is object or object in args: obtained_type = _create_custom_type(obtained_type, schema, json_schema_id_to_schema) - if obtained_type is list and "items" in schema: - item_type = _convert_json_schema_to_our_schema(schema["items"], json_schema_id_to_schema) - obtained_type = list[item_type] + obtained_type = _handle_list_and_union(obtained_type, schema, json_schema_id_to_schema) json_schema_id_to_schema[id(schema)] = obtained_type return obtained_type @@ -188,14 +188,27 @@ def _create_custom_type(obtained_type: typing.Type|None, schema: dict[str, typin }) _counter += 1 - if obtained_type is None: - json_schema_id_to_schema[id(schema)] = typing.Union[str, float, int, bool, None, list[typing.Any], new_type] - elif object in typing.get_args(obtained_type): + if object in typing.get_args(obtained_type): json_schema_id_to_schema[id(schema)] = typing.Union[tuple(item for item in typing.get_args(obtained_type) if item is not object)+(new_type,)] else: json_schema_id_to_schema[id(schema)] = new_type return json_schema_id_to_schema[id(schema)] +def _handle_list_and_union(obtained_type: typing.Type, schema: dict[str, typing.Any], json_schema_id_to_schema: dict[int, typing.Type]) -> typing.Type: + """ + Handle cases where the obtained type is a list or a union containing a list. + """ + if obtained_type is list or (typing.get_origin(obtained_type) is typing.Union and list in typing.get_args(obtained_type)): + if "items" in schema: + item_type = _convert_json_schema_to_our_schema(schema["items"], json_schema_id_to_schema) + if obtained_type is list: + return list[item_type] + else: + args = typing.get_args(obtained_type) + new_args = tuple(list[item_type] if arg is list else arg for arg in args) + return typing.Union[new_args] + return obtained_type + def _obtain_type(schema: dict[str, typing.Any], json_schema_id_to_schema:dict[int, typing.Type]) -> typing.Type[typing.Any|None]: """ diff --git a/tests/snapshots/snap_test_grammar_gen.py b/tests/snapshots/snap_test_grammar_gen.py index 5c564db..e82899b 100644 --- a/tests/snapshots/snap_test_grammar_gen.py +++ b/tests/snapshots/snap_test_grammar_gen.py @@ -36,9 +36,9 @@ start_concepts_value ::= string; start_related_queries ::= array_begin (start_related_queries_value (comma start_related_queries_value)*)? array_end; start_related_queries_value ::= start_related_queries_value_0 | start_related_queries_value_1; -start_related_queries_value_1 ::= object_begin \'"foo"\' colon start_related_queries_value_1_foo object_end; -start_related_queries_value_1_foo ::= integer; -start_related_queries_value_0 ::= string; +start_related_queries_value_1 ::= string; +start_related_queries_value_0 ::= object_begin \'"foo"\' colon start_related_queries_value_0_foo object_end; +start_related_queries_value_0_foo ::= integer; start_queries ::= array_begin (start_queries_value (comma start_queries_value)*)? array_end; start_queries_value ::= start_queries_value_0 | start_queries_value_1 | start_queries_value_2; start_queries_value_2 ::= boolean; @@ -574,5 +574,6 @@ start_1_age ::= start_1_age_required?; start_1_age_required ::= integer; start_1_name ::= string; -start_0 ::= array; +start_0 ::= array_begin (start_0_value (comma start_0_value)*)? array_end; +start_0_value ::= string; '''