diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index dbae20281..2d947f372 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -91,6 +91,18 @@ def build_regex_from_object( return to_regex(resolver, content, whitespace_pattern) +def _get_num_items_pattern(min_items, max_items, whitespace_pattern): + # Helper function for arrays and objects + min_items = int(min_items or 0) + if max_items is None: + return rf"{{{max(min_items - 1, 0)},}}" + else: + max_items = int(max_items) + if max_items < 1: + return None + return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" + + def to_regex( resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None ): @@ -266,18 +278,13 @@ def to_regex( return type_to_regex["integer"] elif instance_type == "array": - min_items = int(instance.get("minItems", "0")) - max_items = instance.get("maxItems", None) - max_items = max_items if max_items is None else int(max_items) + num_repeats = _get_num_items_pattern( + instance.get("minItems"), instance.get("maxItems"), whitespace_pattern + ) + if num_repeats is None: + return rf"\[{whitespace_pattern}\]" - if max_items is None: - num_repeats = rf"{{{max(min_items - 1, 0)},}}" - else: - if max_items < 1: - return rf"\[{whitespace_pattern}\]" - num_repeats = rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" - - allow_empty = "?" if min_items == 0 else "" + allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" if "items" in instance: items_regex = to_regex(resolver, instance["items"], whitespace_pattern) @@ -296,6 +303,39 @@ def to_regex( regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]" + elif instance_type == "object": + # pattern for json object with values defined by instance["additionalProperties"] + # enforces value type constraints recursively, "minProperties", and "maxProperties" + # doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of" + num_repeats = _get_num_items_pattern( + instance.get("minProperties"), + instance.get("maxProperties"), + whitespace_pattern, + ) + if num_repeats is None: + return rf"\{{{whitespace_pattern}\}}" + + allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" + + value_pattern = to_regex( + resolver, instance["additionalProperties"], whitespace_pattern + ) + key_value_pattern = ( + f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" + ) + key_value_successor_pattern = ( + f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" + ) + multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}" + + return ( + r"\{" + + whitespace_pattern + + multiple_key_value_pattern + + whitespace_pattern + + r"\}" + ) + elif instance_type == "boolean": return type_to_regex["boolean"] diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 833b3e884..5afc19a80 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -234,6 +234,64 @@ def test_match_number(pattern, does_match): rf"\[{WHITESPACE}\]", [("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)], ), + # object + ( + { + "title": "TestSchema", + "type": "object", + "properties": { + "test_dict": { + "title": "Test Dict", + "additionalProperties": {"type": "string"}, + "type": "object", + } + }, + "required": ["test_dict"], + }, + rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + [ + ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), + ("""{ "test_dict":{"foo":"bar"\n}}""", True), + ("""{ "test_dict":{}}""", True), + ("""{ "WRONG_KEY":{}}""", False), + ("""{ "test_dict":{"wrong_type" 1}}""", False), + ], + ), + # object containing object + ( + { + "title": "TestSchema", + "type": "object", + "properties": { + "test_dict": { + "title": "Test Dict", + "additionalProperties": { + "additionalProperties": {"type": "integer"}, + "type": "object", + }, + "type": "object", + } + }, + "required": ["test_dict"], + }, + rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + [ + ( + """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", + True, + ), + ( + """{"test_dict": {"anykey": {"anykey": 123}, "anykey2": {"bif": 456}}}""", + True, + ), + ("""{"test_dict": {}}""", True), + ("""{"test_dict": {"dict of empty dicts are ok": {} }}""", True), + ( + """{"test_dict": {"anykey": {"ONLY Dict[Dict]": 123}, "No Dict[int]" 1: }}""", + False, + ), + ], + ), # oneOf ( { @@ -464,6 +522,8 @@ def test_match(schema, regex, examples): for string, does_match in examples: match = re.fullmatch(test_regex, string) if does_match: + if match is None: + raise ValueError(f"Expected match for '{string}'") assert match[0] == string assert match.span() == (0, len(string)) else: