From 11af6cee3dfce185e713974fb56724c02f57af34 Mon Sep 17 00:00:00 2001 From: Samuel MAGNAN Date: Tue, 11 Jun 2024 05:14:14 -0400 Subject: [PATCH] Support min/max number of digits for numbers in JSON Schema (#932) This makes it possible to mitigate some repetition issues that certain LLMs have. This should resolve #847 Co-authored-by: Samuel Magnan --- outlines/fsm/json_schema.py | 86 ++++++++++++++++++- tests/fsm/test_json_schema.py | 156 ++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+), 2 deletions(-) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 3bd4816a9..810ef5910 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -2,7 +2,7 @@ import json import re import warnings -from typing import Callable, Optional +from typing import Callable, Optional, Tuple from jsonschema.protocols import Validator from pydantic import create_model @@ -96,6 +96,47 @@ def _get_num_items_pattern(min_items, max_items, whitespace_pattern): return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" +def validate_quantifiers( + min_bound: Optional[str], max_bound: Optional[str], start_offset: int = 0 +) -> Tuple[str, str]: + """ + Ensures that the digit-count bounds of a number are valid. Bounds are used as quantifiers in the regex. + + Parameters + ---------- + min_bound + The minimum number of digits allowed, or None for no lower bound. + max_bound + The maximum number of digits allowed, or None for no upper bound. + start_offset + Number of elements that are already present in the regex but still need to be counted. + e.g. if the regex is already "(-)?(0|[1-9][0-9])", we will always have at least 1 digit, so the start_offset is 1. + + Returns + ------- + min_bound + The lower quantifier bound as a string (start_offset already subtracted), or "" if unbounded. + max_bound + The upper quantifier bound as a string (start_offset already subtracted), or "" if unbounded. + + Raises + ------ + ValueError + If the minimum bound is greater than the maximum bound. + + TypeError or ValueError + If the minimum bound is not an integer or None. 
+ or + If the maximum bound is not an integer or None. + """ + min_bound = "" if min_bound is None else str(int(min_bound) - start_offset) + max_bound = "" if max_bound is None else str(int(max_bound) - start_offset) + if min_bound and max_bound: + if int(max_bound) < int(min_bound): + raise ValueError("max bound must be greater than or equal to min bound") + return min_bound, max_bound + + def to_regex( resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None ): @@ -263,7 +304,7 @@ def to_regex( if int(max_items) < int(min_items): raise ValueError( "maxLength must be greater than or equal to minLength" - ) + ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) except ValueError: pass return f'"{STRING_INNER}{{{min_items},{max_items}}}"' @@ -291,9 +332,50 @@ def to_regex( return type_to_regex["string"] elif instance_type == "number": + bounds = { + "minDigitsInteger", + "maxDigitsInteger", + "minDigitsFraction", + "maxDigitsFraction", + "minDigitsExponent", + "maxDigitsExponent", + } + if bounds.intersection(set(instance.keys())): + min_digits_integer, max_digits_integer = validate_quantifiers( + instance.get("minDigitsInteger"), + instance.get("maxDigitsInteger"), + start_offset=1, + ) + min_digits_fraction, max_digits_fraction = validate_quantifiers( + instance.get("minDigitsFraction"), instance.get("maxDigitsFraction") + ) + min_digits_exponent, max_digits_exponent = validate_quantifiers( + instance.get("minDigitsExponent"), instance.get("maxDigitsExponent") + ) + integers_quantifier = ( + f"{{{min_digits_integer},{max_digits_integer}}}" + if min_digits_integer or max_digits_integer + else "*" + ) + fraction_quantifier = ( + f"{{{min_digits_fraction},{max_digits_fraction}}}" + if min_digits_fraction or max_digits_fraction + else "+" + ) + exponent_quantifier = ( + f"{{{min_digits_exponent},{max_digits_exponent}}}" + if min_digits_exponent or max_digits_exponent + else "+" + ) + return 
rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" return type_to_regex["number"] elif instance_type == "integer": + if "minDigits" in instance or "maxDigits" in instance: + min_digits, max_digits = validate_quantifiers( + instance.get("minDigits"), instance.get("maxDigits"), start_offset=1 + ) + return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})" return type_to_regex["integer"] elif instance_type == "array": diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index e691db374..f2cc4115b 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -218,6 +218,162 @@ def test_match_number(pattern, does_match): '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', [('{ "count": 100 }', True)], ), + # integer with minimum digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": {"title": "Count", "type": "integer", "minDigits": 3} + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}', + [('{ "count": 10 }', False), ('{ "count": 100 }', True)], + ), + # integer with maximum digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": {"title": "Count", "type": "integer", "maxDigits": 3} + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}', + [('{ "count": 100 }', True), ('{ "count": 1000 }', False)], + ), + # integer with minimum and maximum digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "integer", + "minDigits": 3, + "maxDigits": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}', + [ + ('{ "count": 10 }', False), + ('{ "count": 100 }', True), + ('{ "count": 10000 }', True), + ('{ "count": 100000 }', False), + ], + ), + # number + ( + { + "title": "Foo", + "type": "object", + "properties": {"count": {"title": "Count", 
"type": "number"}}, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', + [('{ "count": 100 }', True), ('{ "count": 100.5 }', True)], + ), + # number with min and max integer digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsInteger": 3, + "maxDigitsInteger": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', + [ + ('{ "count": 10.005 }', False), + ('{ "count": 100.005 }', True), + ('{ "count": 10000.005 }', True), + ('{ "count": 100000.005 }', False), + ], + ), + # number with min and max fraction digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsFraction": 3, + "maxDigitsFraction": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]{3,5})?([eE][+-][0-9]+)?[ ]?\\}', + [ + ('{ "count": 1.05 }', False), + ('{ "count": 1.005 }', True), + ('{ "count": 1.00005 }', True), + ('{ "count": 1.000005 }', False), + ], + ), + # number with min and max exponent digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsExponent": 3, + "maxDigitsExponent": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]{3,5})?[ ]?\\}', + [ + ('{ "count": 1.05e1 }', False), + ('{ "count": 1.05e+001 }', True), + ('{ "count": 1.05e-00001 }', True), + ('{ "count": 1.05e0000001 }', False), + ], + ), + # number with min and max integer, fraction and exponent digits + ( + { + "title": "Foo", + "type": "object", + "properties": { + "count": { + "title": "Count", + "type": "number", + "minDigitsInteger": 3, + "maxDigitsInteger": 5, + "minDigitsFraction": 3, + "maxDigitsFraction": 5, + "minDigitsExponent": 
3, + "maxDigitsExponent": 5, + } + }, + "required": ["count"], + }, + '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]{3,5})?([eE][+-][0-9]{3,5})?[ ]?\\}', + [ + ('{ "count": 1.05e1 }', False), + ('{ "count": 100.005e+001 }', True), + ('{ "count": 10000.00005e-00001 }', True), + ('{ "count": 100000.000005e0000001 }', False), + ], + ), # array ( {"title": "Foo", "type": "array", "items": {"type": "number"}},