diff --git a/docs/reference/generation/json.md b/docs/reference/generation/json.md index da9f14729..77bc06504 100644 --- a/docs/reference/generation/json.md +++ b/docs/reference/generation/json.md @@ -34,14 +34,21 @@ print(result) # User(name="John", last_name="Doe", id=11) ``` -!!! Note "JSON and whitespaces" +!!! Note "JSON and unlimited patterns" - By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows: + By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. Additionally by default strings cannot be longer than 256 characters, and integers are bound between -1e19 and 1e19. Small models tend to enter an infinite repetition loop if JSON schema generation isn't constrained. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows: ```python generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` + Or you can remove all implicit constraints on json generation (whitespace, integer, and string) with + + ```python + generator = generate.json(model, User, safe_subset=False) + ``` + + !!! Note "Performance" `generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once. diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 98d2de59c..646e83509 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,8 +1,10 @@ import inspect +import itertools import json +import math import re import warnings -from typing import Callable, Optional, Tuple, Type, Union +from typing import Callable, List, Optional, Tuple, Type, Union from jsonschema.protocols import Validator from pydantic import BaseModel, create_model @@ -18,14 +20,21 @@ NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" NULL = r"null" -WHITESPACE = r"[ ]?" +WHITESPACE = r"[\n\t ]*" +SAFE_WHITESPACE = r"[ ]?" +SAFE_INT_MAX = int(1e19) +SAFE_INT_MIN = int(-1e19) +SAFE_STR_MAX_LEN = 256 + + +# TODO: Deprecate? This isn't used anywhere internally type_to_regex = { "string": STRING, - "integer": INTEGER, "number": NUMBER, "boolean": BOOLEAN, "null": NULL, + "integer": INTEGER, } DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"' @@ -41,7 +50,9 @@ } -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): +def build_regex_from_schema( + schema: str, whitespace_pattern: Optional[str] = None, safe_subset: bool = True +): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. @@ -60,6 +71,13 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + safe_subset + Use a subset of json schema which performs better with language models. + If you want to all the model to generate any json structure, set to False. + Changes the following: + - If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?") + - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19] + - If unconstrained string is used, constrain it to max of 256 characters Returns ------- @@ -83,7 +101,7 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non resolver = registry.resolver() content = schema.contents - return to_regex(resolver, content, whitespace_pattern) + return to_regex(resolver, content, whitespace_pattern, safe_subset) def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: @@ -173,7 +191,10 @@ def validate_quantifiers( def to_regex( - resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None + resolver: Resolver, + instance: dict, + whitespace_pattern: Optional[str] = None, + safe_subset: bool = True, ): """Translate a JSON Schema instance into a regex that validates the schema. @@ -196,11 +217,18 @@ def to_regex( whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + safe_subset + Use a subset of json schema which performs better with language models. + If you want to all the model to generate any json structure, set to False. + Changes the following: + - If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?") + - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19] + - If unconstrained string is used, constrain it to max of 256 characters """ # set whitespace pattern if whitespace_pattern is None: - whitespace_pattern = WHITESPACE + whitespace_pattern = SAFE_WHITESPACE if safe_subset else WHITESPACE if instance == {}: # JSON Schema Spec: Empty object means unconstrained, any json type is legal @@ -213,7 +241,9 @@ def to_regex( {"type": "array"}, {"type": "object"}, ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] + regexes = [ + to_regex(resolver, t, whitespace_pattern, safe_subset) for t in types + ] regexes = [rf"({r})" for r in regexes] return rf"{'|'.join(regexes)}" @@ -231,7 +261,7 @@ def to_regex( last_required_pos = max([i for i, value in enumerate(is_required) if value]) for i, (name, value) in enumerate(properties.items()): subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) + subregex += to_regex(resolver, value, whitespace_pattern, safe_subset) if i < last_required_pos: subregex = f"{subregex}{whitespace_pattern}," elif i > last_required_pos: @@ -245,7 +275,7 @@ def to_regex( property_subregexes = [] for i, (name, value) in enumerate(properties.items()): subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) + subregex += to_regex(resolver, value, whitespace_pattern, safe_subset) property_subregexes.append(subregex) possible_patterns = [] for i in range(len(property_subregexes)): @@ -266,7 +296,8 @@ def to_regex( # given subschemas. elif "allOf" in instance: subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["allOf"] ] subregexes_str = [f"{subregex}" for subregex in subregexes] return rf"({''.join(subregexes_str)})" @@ -275,7 +306,8 @@ def to_regex( # any (one or more) of the given subschemas. elif "anyOf" in instance: subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["anyOf"] ] return rf"({'|'.join(subregexes)})" @@ -283,7 +315,8 @@ def to_regex( # one of the given subschemas. elif "oneOf" in instance: subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["oneOf"] ] xor_patterns = [f"(?:{subregex})" for subregex in subregexes] @@ -293,7 +326,8 @@ def to_regex( # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx elif "prefixItems" in instance: element_patterns = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in instance["prefixItems"] ] comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" tuple_inner = comma_split_pattern.join(element_patterns) @@ -321,7 +355,7 @@ def to_regex( elif "$ref" in instance: path = f"{instance['$ref']}" instance = resolver.lookup(path).contents - return to_regex(resolver, instance, whitespace_pattern) + return to_regex(resolver, instance, whitespace_pattern, safe_subset) # The type keyword may either be a string or an array: # - If it's a string, it is the name of one of the basic types. @@ -332,16 +366,11 @@ def to_regex( instance_type = instance["type"] if instance_type == "string": if "maxLength" in instance or "minLength" in instance: - max_items = instance.get("maxLength", "") - min_items = instance.get("minLength", "") - try: - if int(max_items) < int(min_items): - raise ValueError( - "maxLength must be greater than or equal to minLength" - ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) - except ValueError: - pass - return f'"{STRING_INNER}{{{min_items},{max_items}}}"' + return get_str_pattern( + min_length=instance.get("minLength"), + max_length=instance.get("maxLength"), + safe_subset=safe_subset, + ) elif "pattern" in instance: pattern = instance["pattern"] if pattern[0] == "^" and pattern[-1] == "$": @@ -363,9 +392,11 @@ def to_regex( f"Format {format} is not supported by Outlines" ) else: - return type_to_regex["string"] + return get_str_pattern(safe_subset=safe_subset) elif instance_type == "number": + # TODO: implement actualy json schema spec parameters: "maximum" and "minimum", + # should be easy through extending get_int_range_pattern bounds = { "minDigitsInteger", "maxDigitsInteger", @@ -402,15 +433,24 @@ def to_regex( else "+" ) return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" - return type_to_regex["number"] + return NUMBER elif instance_type == "integer": - if "minDigits" in instance or "maxDigits" in instance: - min_digits, max_digits = validate_quantifiers( - instance.get("minDigits"), instance.get("maxDigits"), start_offset=1 + # TODO: Remove errors eventulaly - these keys aren't part of json schema spec + if "maxDigits" in instance: + raise ValueError( + "'maxDigits' is not supported. Please use 'minimum' instead." + ) + if "minDigits" in instance: + raise ValueError( + "'minDigits' is not supported. Please use 'minimum' instead." ) - return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})" - return type_to_regex["integer"] + + return get_int_pattern( + minimum=instance.get("minimum"), + maximum=instance.get("maximum"), + safe_subset=safe_subset, + ) elif instance_type == "array": num_repeats = _get_num_items_pattern( @@ -422,7 +462,9 @@ def to_regex( allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" if "items" in instance: - items_regex = to_regex(resolver, instance["items"], whitespace_pattern) + items_regex = to_regex( + resolver, instance["items"], whitespace_pattern, safe_subset + ) return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" else: # Here we need to make the choice to exclude generating list of objects @@ -441,7 +483,8 @@ def to_regex( legal_types.append({"type": "array", "depth": depth - 1}) regexes = [ - to_regex(resolver, t, whitespace_pattern) for t in legal_types + to_regex(resolver, t, whitespace_pattern, safe_subset) + for t in legal_types ] return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" @@ -481,11 +524,12 @@ def to_regex( legal_types.append({"type": "array", "depth": depth - 1}) additional_properties = {"anyOf": legal_types} + key_pattern = get_str_pattern(safe_subset=safe_subset) value_pattern = to_regex( - resolver, additional_properties, whitespace_pattern + resolver, additional_properties, whitespace_pattern, safe_subset ) key_value_pattern = ( - f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" + f"{key_pattern}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" ) key_value_successor_pattern = ( f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" @@ -501,17 +545,17 @@ def to_regex( ) elif instance_type == "boolean": - return type_to_regex["boolean"] + return BOOLEAN elif instance_type == "null": - return type_to_regex["null"] + return NULL elif isinstance(instance_type, list): # Here we need to make the choice to exclude generating an object # if the specification of the object is not give, even though a JSON # object that contains an object here would be valid under the specification. regexes = [ - to_regex(resolver, {"type": t}, whitespace_pattern) + to_regex(resolver, {"type": t}, whitespace_pattern, safe_subset) for t in instance_type if t != "object" ] @@ -550,3 +594,221 @@ def get_schema_from_signature(fn: Callable) -> str: model = create_model(fn_name, **arguments) return model.model_json_schema() + + +def get_subranges(minimum: int, maximum: int) -> List[Tuple]: + """ + Convert a range into a list of subranges which can fit into a pattern + + E.g. minimum=123, maximum=456 cannot easily be made into a regex pattern + therefore, (123, 456) is converted to + [(123, 129), (130, 199), (200, 399), (400, 449), (450, 456)] + which can be converted in get_subrange_pattern() to + ["12[3-9]", "(1[3-9][0-9]{1}", "[2-3][0-9]{2}", "4[0-4][0-9]{1}", "45[0-6]"] + """ + min_str = str(minimum).zfill(len(str(maximum))) + max_str = str(maximum) + + # if only the last digit varies, its a valid subrange + if min_str[:-1] == max_str[:-1]: + return [(minimum, maximum)] + + # calculate the shared prefix between minimum and maximum and left-truncate it for now + num_shared_prefix = len( + list(itertools.takewhile(lambda x: x[0] == x[1], zip(min_str, max_str))) + ) + shared_min = min_str[num_shared_prefix:] + shared_max = max_str[num_shared_prefix:] + prefix = min_str[:num_shared_prefix] + + # determine how many trailing digits back are valid [0-9] + # set first digit which doesn't qualify as the flex + # then combine: {prefix}{flex}[0-9]{count} + num_truncate = len(shared_min) - len(shared_min.rstrip("0")) + 1 + child_max = int(prefix + shared_min[:-num_truncate] + "9" * num_truncate) + if child_max > maximum: + child_max = int(prefix + shared_max[0] + "0" * len(shared_max[1:])) - 1 + + if child_max == maximum: + return [(minimum, child_max)] + return [(minimum, child_max)] + get_subranges(child_max + 1, maximum) + + +def get_subrange_pattern(minimum: int, maximum: int) -> str: + """ + Generates a regex pattern for a subrange where digits can be represented using character classes. + + This function creates a regex pattern for a given integer subrange where the digits can be + represented using character classes and quantifiers. It assumes that the range can be represented + by varying specific digits while others remain constant or within a simple range. + + For example: + - (200, 399) -> '([2-3][0-9]{2})' + - (310, 319) -> '(31[0-9])' + - (100, 189) -> '(1[0-8][0-9])' + + The function should only be called with ranges that can be represented in this way. + It does not handle ranges where digits do not align for simple character classes. + + Args: + minimum (int): The lower bound of the integer subrange. + maximum (int): The upper bound of the integer subrange. + + Returns: + str: A regex pattern string that matches all integers in the subrange. + """ + + max_str = str(maximum) + min_str = str(minimum).zfill(len(max_str)) + + last_range_zero = len(min_str) - re.search(r"[^0]|$", min_str[::-1]).start() # type: ignore + last_range_nine = len(max_str) - re.search(r"[^9]|$", max_str[::-1]).start() # type: ignore + if last_range_zero is None or last_range_nine is None: + raise RuntimeError(f"invalid string range: {minimum} to {maximum}") + full_range_start = max(last_range_zero, last_range_nine) + + shared_prefix = min_str[: full_range_start - 1] + range_digit_min, range_digit_max = ( + min_str[full_range_start - 1], + max_str[full_range_start - 1], + ) + + pattern = rf"{shared_prefix}[{range_digit_min}-{range_digit_max}]" + + num_0_9_chars = len(max_str) - full_range_start + if num_0_9_chars: + pattern += rf"[0-9]{{{num_0_9_chars}}}" + + return rf"({pattern})" + + +def get_positive_int_range_pattern(minimum: int, maximum: int) -> str: + """ + Generates a regex pattern for positive integers within a specified range. + + This function creates a regex pattern that matches positive integers from `minimum` to `maximum`. + It handles ranges with finite and infinite upper bounds, and can include zero explicitly if + needed. + + The function splits the range into subranges suitable for pattern generation, and combines + the patterns for each subrange using alternation (the '|' operator). + + Args: + minimum (int or inf): The lower bound of the integer range (must be >= 0). + maximum (int or inf): The upper bound of the integer range (must be >= 0 or infinity). + + Returns: + str: A regex pattern string that matches all positive integers in the range. + """ + assert minimum >= 0 + assert maximum >= 0 + + # special case, 0 through 10000... allows single simple pattern + if minimum == 0 and set(str(maximum)[1:]) == set("0") and str(maximum)[0] == "1": + return rf"(({maximum})|([1-9][0-9]{{0,{len(str(maximum))-2}}}))" + + # Handle the case where zero needs to be included explicitly. + if minimum == 0 and maximum == 0: + return "" # no pattern, 0 is handled by calling fn + elif minimum == 0: + minimum = 1 + explicit_zero = True # Flag to include OR Zero (`|0`) in the final pattern. + else: + explicit_zero = False + + if maximum == float("inf"): + # Handle infinite upper bound. + # Create and OR two patterns: (minimum, lower_maximum - 1) | (lower_maximum, infinity) + lower_maximum = 10 ** math.ceil(math.log10(minimum + 1)) - 1 + lower_maximum_pattern = "|".join( + [ + get_subrange_pattern(sub_min, sub_max) + for sub_min, sub_max in get_subranges(minimum, lower_maximum) + ] + ) + lower_max_to_infinity_pattern = rf"[\d]{{{len(str(lower_maximum))+1},}}" + pattern = f"({lower_max_to_infinity_pattern}|{lower_maximum_pattern})" + else: + pattern = "|".join( + [ + get_subrange_pattern(sub_min, sub_max) + for sub_min, sub_max in get_subranges(minimum, maximum) + ] + ) + + if explicit_zero: + pattern = rf"(({pattern}))" + + return pattern + + +def get_int_pattern(minimum=None, maximum=None, safe_subset: bool = False) -> str: + """ + This function generates a regex pattern that matches integers from `minimum` to `maximum`, + inclusive. It handles negative ranges, positive ranges, zero, and ranges that span both negative + and positive numbers. + + If no bounds are specified, it defaults to matching all integers. The `safe_subset` parameter + can be used to limit the range to safe integer values (e.g., to avoid excessively large numbers). + + Args: + minimum (int, (+/-)inf, optional): The lower bound of the integer range. Defaults to negative infinity. + maximum (int, (+/-)inf, optional): The upper bound of the integer range. Defaults to positive infinity. + safe_subset (bool, optional): If True, uses SAFE_INT_MIN and SAFE_INT_MAX as default bounds. + + Returns: + str: A regex pattern string that matches all integers in the specified range. + """ + # handle safe subset of range + if minimum is None: + minimum = SAFE_INT_MIN if safe_subset else -float("inf") + if maximum is None: + maximum = SAFE_INT_MAX if safe_subset else float("inf") + + if (minimum, maximum) == (-float("inf"), float("inf")): + return INTEGER + + assert minimum <= maximum + + if minimum == maximum == 0: + patterns = [] + elif minimum >= 0 and maximum >= 0: + patterns = [get_positive_int_range_pattern(minimum, maximum)] + elif minimum < 0 and maximum <= 0: + # entirely negative range: prefix with `-` and calculate abs of range + abs_pattern = get_positive_int_range_pattern(max(abs(maximum), 1), abs(minimum)) + patterns = [rf"-({abs_pattern})"] + elif maximum == -minimum: + maximum_pattern = get_positive_int_range_pattern(0, maximum) + patterns = [rf"((-?){maximum_pattern})"] + else: # minimum < 0 and maximum > 0: + # positive component of range | negative component + minimum_pattern = get_positive_int_range_pattern(0, abs(minimum)) + maximum_pattern = get_positive_int_range_pattern(0, maximum) + patterns = [rf"(-({minimum_pattern}))|({maximum_pattern})"] + + if minimum <= 0 <= maximum: + patterns.append("(-?0)") + + return "|".join(patterns) + + +def get_str_pattern( + min_length: Optional[int] = None, + max_length: Optional[int] = None, + safe_subset: bool = False, +) -> str: + if min_length is None and max_length is None and not safe_subset: + return STRING + elif min_length and max_length and int(max_length or 0) < int(min_length or 0): + raise ValueError("maxLength must be greater than or equal to minLength") + elif (min_length or 0) < 0 or (max_length or 0) < 0: + raise ValueError("minLength and maxLength must be greater than or equal to 0") + + range_begin = str(min_length) if min_length else "" + if max_length is None: + range_end = str(SAFE_STR_MAX_LEN) if safe_subset else "" + else: + range_end = str(max_length) + + return f'"{STRING_INNER}{{{range_begin},{range_end}}}"' diff --git a/outlines/generate/json.py b/outlines/generate/json.py index 6209840e2..cbbcc3760 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -18,6 +18,7 @@ def json( schema_object: Union[str, object, Callable], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, + safe_subset: bool = True, ) -> SequenceGeneratorAdapter: """ Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. @@ -36,6 +37,14 @@ def json( whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + safe_subset + Use a subset of json schema which performs better with language models. + If you want to all the model to generate any json structure, set to False. + Changes the following: + - If whitespace_pattern is None, sets whitespace pattern to WHITESPACE (r"[ ]?") + - If unconstrained integer is used, constrain integer to *roughly* the int64 range [-1e19, 1e19] + - If unconstrained string is used, constrain it to max of 256 characters Returns ------- @@ -45,17 +54,17 @@ def json( """ if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern, safe_subset) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: schema_object.parse_raw(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern, safe_subset) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) elif isinstance(schema_object, str): schema = schema_object - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema(schema, whitespace_pattern, safe_subset) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) else: diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 7565ff642..459068275 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,5 +1,7 @@ import json +import random import re +import string as pystring from typing import List, Literal, Union import interegular @@ -13,16 +15,21 @@ INTEGER, NULL, NUMBER, + SAFE_WHITESPACE, STRING, STRING_INNER, TIME, UUID, - WHITESPACE, build_regex_from_schema, + get_int_pattern, get_schema_from_signature, + get_str_pattern, to_regex, ) +SAFE_INT = get_int_pattern(safe_subset=True) +SAFE_STR = get_str_pattern(safe_subset=True) + def test_function_basic(): def test_function(foo: str, bar: List[int]): @@ -71,7 +78,7 @@ class User(BaseModel): ) def test_match_integer(pattern, does_match): step = {"title": "Foo", "type": "integer"} - regex = to_regex(None, step) + regex = to_regex(None, step, safe_subset=False) assert regex == INTEGER value = pattern["integer"] @@ -83,6 +90,102 @@ def test_match_integer(pattern, does_match): assert match is None +@pytest.mark.parametrize( + "minimum,maximum", + [ + (0, 0), + (-1, 0), + (0, 1), + (-15, 0), + (0, 15), + (-1, 1), + (-15, 15), + (-1234, 56), + (-56, 1234), + (-9, 9), + (-10, 10), + (-9, 10), + (-10, 9), + (123, 199), + (123, 456), + (5600, 5678), + (550, 560), + (-12345, 3423), + (50, 10000), + (0, 1000), + (-100000, 0), + (-100000, 100000), + ], +) +def test_int_range_pattern(minimum, maximum): + pattern = get_int_pattern(minimum, maximum) + fsm = interegular.parse_pattern(pattern).to_fsm().reduce() + pattern_numbers = {"".join(s) for s in fsm.strings()} + range_numbers = set(map(str, range(minimum, maximum + 1))) + if "0" in range_numbers: + range_numbers.add("-0") + assert pattern_numbers == range_numbers + + # logarithmic space complexity + assert len(fsm.states) <= (len(str(minimum)) + len(str(maximum))) * 2 + + +def test_int_range_unconstrained(): + # test unconstrained + pattern = get_int_pattern(float("-inf"), float("inf")) + fsm = interegular.parse_pattern(pattern).to_fsm().reduce() + assert get_int_pattern(None, None) == pattern + assert fsm.accepts("0") + assert fsm.accepts("-1") + assert fsm.accepts("1") + assert fsm.accepts("-98427983498234893274983274892") + assert fsm.accepts("2994399439493294329432984932") + + assert not fsm.accepts("1.1") + assert not fsm.accepts("1.0") + assert not fsm.accepts("1.0") + assert not fsm.accepts("one") + + assert len(fsm.states) < 5 + + +def test_int_range_min_zero(): + # test min zero + pattern = get_int_pattern(0, float("inf")) + fsm = interegular.parse_pattern(pattern).to_fsm() + assert fsm.accepts("0") + assert not fsm.accepts("-1") + assert fsm.accepts("1") + assert not fsm.accepts("-98427983498234893274983274892") + assert fsm.accepts("2994399439493294329432984932") + + +def test_int_range_max_zero(): + # test min zero + pattern = get_int_pattern(-float("inf"), 0) + fsm = interegular.parse_pattern(pattern).to_fsm() + assert fsm.accepts("0") + assert fsm.accepts("-1") + assert not fsm.accepts("1") + assert fsm.accepts("-98427983498234893274983274892") + assert not fsm.accepts("2994399439493294329432984932") + + +def test_int_range_max_minus_32(): + # test min zero + pattern = get_int_pattern(-float("inf"), -32) + fsm = interegular.parse_pattern(pattern).to_fsm() + assert not fsm.accepts("0") + assert not fsm.accepts("-1") + assert not fsm.accepts("1") + assert not fsm.accepts("32") + assert fsm.accepts("-32") + assert fsm.accepts("-33") + assert fsm.accepts("-39482929438") + assert fsm.accepts("-98427983498234893274983274892") + assert not fsm.accepts("2994399439493294329432984932") + + @pytest.mark.parametrize( "pattern,does_match", [ @@ -110,6 +213,56 @@ def test_match_number(pattern, does_match): assert match is None +@pytest.mark.parametrize( + "min_len,max_len,safe_subset,expected_max,errors", + [ + # if no max and no safe mode, any length allowed + (None, None, False, None, False), + # max_len is None, use safe max_len of 256 + (None, None, True, 256, False), + (0, None, True, 256, False), + # if max_len is specified, it overrides safe_subset rules + (None, 500, True, 500, False), + (0, 500, True, 500, False), + # if min_len specification has no effect + (3, 500, True, 500, False), + (30, 500, True, 500, False), + (300, 500, True, 500, False), + # illegal + (-1, None, True, None, True), + (30, 20, True, None, True), + ], +) +def test_get_str_pattern(min_len, max_len, safe_subset, expected_max, errors): + if errors: + with pytest.raises(ValueError): + get_str_pattern(min_len, max_len, safe_subset) + return + + pattern = get_str_pattern(min_len, max_len, safe_subset) + + # verify str len in (min_len, max_len) + def str_of_len(str_len): + s = "".join(random.choices(pystring.ascii_letters + pystring.digits, k=str_len)) + return f'"{s}"' + + # verify min_len held + min_len = min_len or 0 + assert re.match(pattern, str_of_len(min_len)) + if min_len != 0: + assert not re.match(pattern, str_of_len(min_len - 1)) + + # verify expected_max accurate + if expected_max is not None: + assert re.match(pattern, str_of_len(expected_max)) + assert re.match(pattern, str_of_len(max(expected_max - 1, min_len))) + assert not re.match(pattern, str_of_len(expected_max + 1)) + else: + assert re.match(pattern, str_of_len(max(min_len, 100))) + assert re.match(pattern, str_of_len(max(min_len, 1000))) + assert re.match(pattern, str_of_len(max(min_len, 100000))) + + @pytest.mark.parametrize( "schema,regex,examples", [ @@ -267,11 +420,11 @@ def test_match_number(pattern, does_match): "title": "Foo", "type": "object", "properties": { - "count": {"title": "Count", "type": "integer", "minDigits": 3} + "count": {"title": "Count", "type": "integer", "minimum": 100} }, "required": ["count"], }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}', + '\\{[ ]?"count"[ ]?:[ ]?(([\\d]{4,}|([1-9][0-9]{2})))[ ]?\\}', [('{ "count": 10 }', False), ('{ "count": 100 }', True)], ), # integer with maximum digits @@ -280,14 +433,14 @@ def test_match_number(pattern, does_match): "title": "Foo", "type": "object", "properties": { - "count": {"title": "Count", "type": "integer", "maxDigits": 3} + "count": {"title": "Count", "type": "integer", "maximum": 999} }, "required": ["count"], }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}', + '\\{[ ]?"count"[ ]?:[ ]?((-(([\\d]{2,}|([1-9]))))|((0|(([1-9])|([1-9][0-9]{1})|([1-9][0-9]{2})))))[ ]?\\}', [('{ "count": 100 }', True), ('{ "count": 1000 }', False)], ), - # integer with minimum and maximum digits + # integer with minimum and maximum ( { "title": "Foo", @@ -296,13 +449,13 @@ def test_match_number(pattern, does_match): "count": { "title": "Count", "type": "integer", - "minDigits": 3, - "maxDigits": 5, + "minimum": 50, + "maximum": 50000, } }, "required": ["count"], }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}', + '\\{[ ]?"count"[ ]?:[ ]?(([5-9][0-9]{1})|([1-9][0-9]{2})|([1-9][0-9]{3})|([1-4][0-9]{4})|(5000[0-0]))[ ]?\\}', [ ('{ "count": 10 }', False), ('{ "count": 100 }', True), @@ -420,7 +573,7 @@ def test_match_number(pattern, does_match): # array ( {"title": "Foo", "type": "array", "items": {"type": "number"}}, - rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}(({NUMBER})(,{SAFE_WHITESPACE}({NUMBER})){{0,}})?{SAFE_WHITESPACE}\]", [("[1e+9,1.3]", True), ("[]", True), ("[1", False)], ), # array with a set length of 1 @@ -432,7 +585,7 @@ def test_match_number(pattern, does_match): "minItems": 1, "maxItems": 1, }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{0,0}}){WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}(({INTEGER})(,{SAFE_WHITESPACE}({INTEGER})){{0,0}}){SAFE_WHITESPACE}\]", [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)], ), # array with a set length greather than 1 @@ -444,7 +597,7 @@ def test_match_number(pattern, does_match): "minItems": 3, "maxItems": 3, }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}(({INTEGER})(,{SAFE_WHITESPACE}({INTEGER})){{2,2}}){SAFE_WHITESPACE}\]", [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], ), # array with length 0 @@ -456,7 +609,7 @@ def test_match_number(pattern, does_match): "minItems": 0, "maxItems": 0, }, - rf"\[{WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}\]", [("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)], ), # object @@ -473,7 +626,7 @@ def test_match_number(pattern, does_match): }, "required": ["test_dict"], }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + rf"""\{{{SAFE_WHITESPACE}"test_dict"{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{STRING}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{STRING}){{0,}})?{SAFE_WHITESPACE}\}}{SAFE_WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), ("""{ "test_dict":{"foo":"bar" }}""", True), @@ -499,7 +652,7 @@ def test_match_number(pattern, does_match): }, "required": ["test_dict"], }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", + rf"""\{{{SAFE_WHITESPACE}"test_dict"{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}){{0,}})?{SAFE_WHITESPACE}\}}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}\{{{SAFE_WHITESPACE}({STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}({SAFE_WHITESPACE},{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE}:{SAFE_WHITESPACE}{INTEGER}){{0,}})?{SAFE_WHITESPACE}\}}){{0,}})?{SAFE_WHITESPACE}\}}{SAFE_WHITESPACE}\}}""", [ ( """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", @@ -559,7 +712,7 @@ def test_match_number(pattern, does_match): "title": "Foo", "prefixItems": [{"type": "string"}, {"type": "integer"}], }, - rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", + rf"\[{SAFE_WHITESPACE}{STRING}{SAFE_WHITESPACE},{SAFE_WHITESPACE}{INTEGER}{SAFE_WHITESPACE}\]", [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], ), # Nested schema @@ -751,7 +904,9 @@ def test_match_number(pattern, does_match): def test_match(schema, regex, examples): interegular.parse_pattern(regex) schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) + test_regex = build_regex_from_schema( + schema, whitespace_pattern=SAFE_WHITESPACE, safe_subset=False + ) assert test_regex == regex for string, does_match in examples: @@ -1000,10 +1155,10 @@ class MockModel(BaseModel): # assert any ws pattern can be used if whitespace_pattern == "abc": - build_regex_from_schema(schema, whitespace_pattern) + build_regex_from_schema(schema, whitespace_pattern, safe_subset=False) return - pattern = build_regex_from_schema(schema, whitespace_pattern) + pattern = build_regex_from_schema(schema, whitespace_pattern, safe_subset=False) mock_result_mult_ws = ( """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}"""