Skip to content

Commit

Permalink
Resolve objects in JSON Schema (#664)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Lapp <[email protected]>
  • Loading branch information
lapp0 and Andrew Lapp authored Feb 16, 2024
1 parent 7480837 commit f1d2f78
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 11 deletions.
62 changes: 51 additions & 11 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def build_regex_from_object(
return to_regex(resolver, content, whitespace_pattern)


def _get_num_items_pattern(min_items, max_items, whitespace_pattern):
# Helper function for arrays and objects
min_items = int(min_items or 0)
if max_items is None:
return rf"{{{max(min_items - 1, 0)},}}"
else:
max_items = int(max_items)
if max_items < 1:
return None
return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}"


def to_regex(
resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
):
Expand Down Expand Up @@ -266,18 +278,13 @@ def to_regex(
return type_to_regex["integer"]

elif instance_type == "array":
min_items = int(instance.get("minItems", "0"))
max_items = instance.get("maxItems", None)
max_items = max_items if max_items is None else int(max_items)
num_repeats = _get_num_items_pattern(
instance.get("minItems"), instance.get("maxItems"), whitespace_pattern
)
if num_repeats is None:
return rf"\[{whitespace_pattern}\]"

if max_items is None:
num_repeats = rf"{{{max(min_items - 1, 0)},}}"
else:
if max_items < 1:
return rf"\[{whitespace_pattern}\]"
num_repeats = rf"{{{max(min_items - 1, 0)},{max_items - 1}}}"

allow_empty = "?" if min_items == 0 else ""
allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else ""

if "items" in instance:
items_regex = to_regex(resolver, instance["items"], whitespace_pattern)
Expand All @@ -296,6 +303,39 @@ def to_regex(
regexes = [to_regex(resolver, t, whitespace_pattern) for t in types]
return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]"

elif instance_type == "object":
# pattern for json object with values defined by instance["additionalProperties"]
# enforces value type constraints recursively, "minProperties", and "maxProperties"
# doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of"
num_repeats = _get_num_items_pattern(
instance.get("minProperties"),
instance.get("maxProperties"),
whitespace_pattern,
)
if num_repeats is None:
return rf"\{{{whitespace_pattern}\}}"

allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else ""

value_pattern = to_regex(
resolver, instance["additionalProperties"], whitespace_pattern
)
key_value_pattern = (
f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}"
)
key_value_successor_pattern = (
f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}"
)
multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}"

return (
r"\{"
+ whitespace_pattern
+ multiple_key_value_pattern
+ whitespace_pattern
+ r"\}"
)

elif instance_type == "boolean":
return type_to_regex["boolean"]

Expand Down
60 changes: 60 additions & 0 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,64 @@ def test_match_number(pattern, does_match):
rf"\[{WHITESPACE}\]",
[("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)],
),
# object
(
{
"title": "TestSchema",
"type": "object",
"properties": {
"test_dict": {
"title": "Test Dict",
"additionalProperties": {"type": "string"},
"type": "object",
}
},
"required": ["test_dict"],
},
rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""",
[
("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True),
("""{ "test_dict":{"foo":"bar"\n}}""", True),
("""{ "test_dict":{}}""", True),
("""{ "WRONG_KEY":{}}""", False),
("""{ "test_dict":{"wrong_type" 1}}""", False),
],
),
# object containing object
(
{
"title": "TestSchema",
"type": "object",
"properties": {
"test_dict": {
"title": "Test Dict",
"additionalProperties": {
"additionalProperties": {"type": "integer"},
"type": "object",
},
"type": "object",
}
},
"required": ["test_dict"],
},
rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""",
[
(
"""{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""",
True,
),
(
"""{"test_dict": {"anykey": {"anykey": 123}, "anykey2": {"bif": 456}}}""",
True,
),
("""{"test_dict": {}}""", True),
("""{"test_dict": {"dict of empty dicts are ok": {} }}""", True),
(
"""{"test_dict": {"anykey": {"ONLY Dict[Dict]": 123}, "No Dict[int]" 1: }}""",
False,
),
],
),
# oneOf
(
{
Expand Down Expand Up @@ -464,6 +522,8 @@ def test_match(schema, regex, examples):
for string, does_match in examples:
match = re.fullmatch(test_regex, string)
if does_match:
if match is None:
raise ValueError(f"Expected match for '{string}'")
assert match[0] == string
assert match.span() == (0, len(string))
else:
Expand Down

0 comments on commit f1d2f78

Please sign in to comment.