Skip to content

Commit

Permalink
add safe_subset argument to json_schema.to_regex, implement integer m…
Browse files Browse the repository at this point in the history
…inimum / maximum
  • Loading branch information
lapp0 committed Aug 31, 2024
1 parent 72377db commit d30b33c
Show file tree
Hide file tree
Showing 2 changed files with 309 additions and 37 deletions.
212 changes: 193 additions & 19 deletions outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
import itertools
import json
import math
import re
import warnings
from typing import Callable, Optional, Tuple, Type, Union
Expand All @@ -18,14 +20,16 @@
NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"
WHITESPACE = r"[ ]?"
WHITESPACE = r"[\n\t ]*"
SAFE_WHITESPACE = r"[ ]?"


type_to_regex = {
"string": STRING,
"integer": INTEGER,
"number": NUMBER,
"boolean": BOOLEAN,
"null": NULL,
"integer": INTEGER,
}

DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"'
Expand All @@ -41,7 +45,145 @@
}


def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None):
def get_subranges(minimum, maximum):
"""
Convert a range into a list of subranges which can fit into a pattern
E.g. minimum=123, maximum=456 cannot easily be made into a regex pattern
therefore, (123, 456) is converted to
[(123, 129), (130, 199), (200, 399), (400, 449), (450, 456)]
which can be converted in get_subrange_pattern() to
["12[3-9]", "(1[3-9][0-9]{1}", "[2-3][0-9]{2}", "4[0-4][0-9]{1}", "45[0-6]"]
"""
min_str = str(minimum).zfill(len(str(maximum)))
max_str = str(maximum)

# if only the last digit varies, its a valid subrange
if min_str[:-1] == max_str[:-1]:
return [(minimum, maximum)]

# calculate the shared prefix between minimum and maximum and left-truncate it for now
num_shared_prefix = len(
list(itertools.takewhile(lambda x: x[0] == x[1], zip(min_str, max_str)))
)
shared_min = min_str[num_shared_prefix:]
shared_max = max_str[num_shared_prefix:]
prefix = min_str[:num_shared_prefix]

# determine how many trailing digits back are valid [0-9]
# set first digit which doesn't qualify as the flex
# then combine: {prefix}{flex}[0-9]{count}
num_truncate = len(shared_min) - len(shared_min.rstrip("0")) + 1
child_max = int(prefix + shared_min[:-num_truncate] + "9" * num_truncate)
if child_max > maximum:
child_max = int(prefix + shared_max[0] + "0" * len(shared_max[1:])) - 1

if child_max == maximum:
return [(minimum, child_max)]
return [(minimum, child_max)] + get_subranges(child_max + 1, maximum)


def get_subrange_pattern(minimum, maximum):
"""Convert (200, 399) to ([2-3][0-9]{2})"""

max_str = str(maximum)
min_str = str(minimum).zfill(len(max_str))

last_range_zero = len(min_str) - re.search(r"[^0]|$", min_str[::-1]).start()
last_range_nine = len(max_str) - re.search(r"[^9]|$", max_str[::-1]).start()
full_range_start = max(last_range_zero, last_range_nine)

shared_prefix = min_str[: full_range_start - 1]
range_digit_min, range_digit_max = (
min_str[full_range_start - 1],
max_str[full_range_start - 1],
)

pattern = rf"{shared_prefix}[{range_digit_min}-{range_digit_max}]"

num_0_9_chars = len(max_str) - full_range_start
if num_0_9_chars:
pattern += rf"[0-9]{{{num_0_9_chars}}}"

return rf"({pattern})"


def get_positive_int_range_pattern(minimum, maximum):
assert minimum >= 0
assert maximum >= 0

if minimum == 0:
minimum = 1
explicit_zero = True
if maximum == 0:
maximum = 1
else:
explicit_zero = False

if maximum == float("inf"):
pseudo_maximum = 10 ** math.ceil(math.log10(minimum + 1)) - 1
pseudo_pattern = "|".join(
[
get_subrange_pattern(sub_min, sub_max)
for sub_min, sub_max in get_subranges(minimum, pseudo_maximum)
]
)
pattern = rf"([\d]{{{len(str(pseudo_maximum))+1},}}|{pseudo_pattern})"
else:
pattern = "|".join(
[
get_subrange_pattern(sub_min, sub_max)
for sub_min, sub_max in get_subranges(minimum, maximum)
]
)

if explicit_zero:
pattern = rf"(0|({pattern}))"

return pattern


def get_int_range_pattern(minimum=None, maximum=None):
"""
Create a pattern which matches all integers in range [minimum, maximum] *inclusive*
"""
if minimum is None:
minimum = -float("inf")
if maximum is None:
maximum = float("inf")

if (minimum, maximum) == (-float("inf"), float("inf")):
return INTEGER

assert minimum <= maximum

if minimum == maximum == 0:
pattern = "0"
elif minimum < 0 and maximum <= 0:
abs_pattern = get_positive_int_range_pattern(max(abs(maximum), 1), abs(minimum))
pattern = rf"-({abs_pattern})"
if maximum == 0:
pattern = rf"0|({pattern})"
elif minimum < 0 and maximum > 0:
minimum_pattern = get_positive_int_range_pattern(1, abs(minimum))
maximum_pattern = get_positive_int_range_pattern(0, maximum)
pattern = rf"(-({minimum_pattern}))|({maximum_pattern})"
elif minimum >= 0 and maximum >= 0:
pattern = get_positive_int_range_pattern(minimum, maximum)
else:
raise RuntimeError("This shouldn't occur, please open an issue")

return rf"({pattern})"


def get_safe_int():
"""10% larger than int64 range"""
return get_int_range_pattern(minimum=-int(1e19), maximum=int(1e19))


def build_regex_from_schema(
schema: str, whitespace_pattern: Optional[str] = None, safe_subset: bool = True
):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.
Expand All @@ -60,6 +202,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
safe_subset
Use a subset of json schema which performs better with language models.
If you want to all the model to generate any json structure, set to False.
Changes the following:
- If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?")
- If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19]
Returns
-------
Expand All @@ -83,7 +231,7 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non
resolver = registry.resolver()

content = schema.contents
return to_regex(resolver, content, whitespace_pattern)
return to_regex(resolver, content, whitespace_pattern, safe_subset)


def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
Expand Down Expand Up @@ -173,7 +321,10 @@ def validate_quantifiers(


def to_regex(
resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
resolver: Resolver,
instance: dict,
whitespace_pattern: Optional[str] = None,
safe_subset: bool = True,
):
"""Translate a JSON Schema instance into a regex that validates the schema.
Expand All @@ -196,11 +347,18 @@ def to_regex(
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
safe_subset
Use a subset of json schema which performs better with language models.
If you want to all the model to generate any json structure, set to False.
Changes the following:
- If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?")
- If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19]
"""

# set whitespace pattern
if whitespace_pattern is None:
whitespace_pattern = WHITESPACE
whitespace_pattern = SAFE_WHITESPACE if safe_subset else WHITESPACE

if instance == {}:
# JSON Schema Spec: Empty object means unconstrained, any json type is legal
Expand All @@ -213,7 +371,9 @@ def to_regex(
{"type": "array"},
{"type": "object"},
]
regexes = [to_regex(resolver, t, whitespace_pattern) for t in types]
regexes = [
to_regex(resolver, t, whitespace_pattern, safe_subset) for t in types
]
regexes = [rf"({r})" for r in regexes]
return rf"{'|'.join(regexes)}"

Expand All @@ -231,7 +391,7 @@ def to_regex(
last_required_pos = max([i for i, value in enumerate(is_required) if value])
for i, (name, value) in enumerate(properties.items()):
subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}'
subregex += to_regex(resolver, value, whitespace_pattern)
subregex += to_regex(resolver, value, whitespace_pattern, safe_subset)
if i < last_required_pos:
subregex = f"{subregex}{whitespace_pattern},"
elif i > last_required_pos:
Expand All @@ -245,7 +405,7 @@ def to_regex(
property_subregexes = []
for i, (name, value) in enumerate(properties.items()):
subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
subregex += to_regex(resolver, value, whitespace_pattern)
subregex += to_regex(resolver, value, whitespace_pattern, safe_subset)
property_subregexes.append(subregex)
possible_patterns = []
for i in range(len(property_subregexes)):
Expand All @@ -266,7 +426,8 @@ def to_regex(
# given subschemas.
elif "allOf" in instance:
subregexes = [
to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"]
to_regex(resolver, t, whitespace_pattern, safe_subset)
for t in instance["allOf"]
]
subregexes_str = [f"{subregex}" for subregex in subregexes]
return rf"({''.join(subregexes_str)})"
Expand All @@ -275,15 +436,17 @@ def to_regex(
# any (one or more) of the given subschemas.
elif "anyOf" in instance:
subregexes = [
to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"]
to_regex(resolver, t, whitespace_pattern, safe_subset)
for t in instance["anyOf"]
]
return rf"({'|'.join(subregexes)})"

# To validate against oneOf, the given data must be valid against exactly
# one of the given subschemas.
elif "oneOf" in instance:
subregexes = [
to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
to_regex(resolver, t, whitespace_pattern, safe_subset)
for t in instance["oneOf"]
]

xor_patterns = [f"(?:{subregex})" for subregex in subregexes]
Expand All @@ -293,7 +456,8 @@ def to_regex(
# Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx
elif "prefixItems" in instance:
element_patterns = [
to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"]
to_regex(resolver, t, whitespace_pattern, safe_subset)
for t in instance["prefixItems"]
]
comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}"
tuple_inner = comma_split_pattern.join(element_patterns)
Expand Down Expand Up @@ -321,7 +485,7 @@ def to_regex(
elif "$ref" in instance:
path = f"{instance['$ref']}"
instance = resolver.lookup(path).contents
return to_regex(resolver, instance, whitespace_pattern)
return to_regex(resolver, instance, whitespace_pattern, safe_subset)

# The type keyword may either be a string or an array:
# - If it's a string, it is the name of one of the basic types.
Expand Down Expand Up @@ -366,6 +530,8 @@ def to_regex(
return type_to_regex["string"]

elif instance_type == "number":
# TODO: implement actualy json schema spec parameters: "maximum" and "minimum",
# should be easy through extending get_int_range_pattern
bounds = {
"minDigitsInteger",
"maxDigitsInteger",
Expand Down Expand Up @@ -405,12 +571,20 @@ def to_regex(
return type_to_regex["number"]

elif instance_type == "integer":
if "minDigits" in instance or "maxDigits" in instance:
min_digits, max_digits = validate_quantifiers(
instance.get("minDigits"), instance.get("maxDigits"), start_offset=1
# TODO: Remove errors eventulaly - these keys aren't part of json schema spec
if "maxDigits" in instance:
raise ValueError(
"'maxDigits' is not supported. Please use 'minimum' instead."
)
if "minDigits" in instance:
raise ValueError(
"'minDigits' is not supported. Please use 'minimum' instead."
)
return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})"
return type_to_regex["integer"]

maximum = instance.get("maximum", int(1e19) if safe_subset else None)
minimum = instance.get("minimum", -int(1e19) if safe_subset else None)

return get_int_range_pattern(minimum, maximum)

elif instance_type == "array":
num_repeats = _get_num_items_pattern(
Expand Down
Loading

0 comments on commit d30b33c

Please sign in to comment.