Skip to content

Commit

Permalink
adds min/max control for all digits of intergers and numbers. This al…
Browse files Browse the repository at this point in the history
…lows to mitigate some LLM that get stuck on repeating the same number non stop.
  • Loading branch information
Samuel Magnan committed May 31, 2024
1 parent 538f77a commit 7f05cea
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,42 @@ def _get_num_items_pattern(min_items, max_items, whitespace_pattern):
return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}"


def validate_quantifiers(min_bound: Optional[str], max_bound: Optional[str]) -> (str, str):
"""
Ensures that the bounds of a number are valid. Bounds are used as quantifiers in the regex.
Parameters
----------
min_bound
The minimum value that the number can take.
max_bound
The maximum value that the number can take.
Returns
-------
min_bound
The minimum value that the number can take.
max_bound
The maximum value that the number can take.
Raises
------
ValueError
If the minimum bound is greater than the maximum bound.
TypeError or ValueError
If the minimum bound is not an integer or None.
or
If the maximum bound is not an integer or None.
"""
min_bound = "" if min_bound is None else str(int(min_bound))
max_bound = "" if max_bound is None else str(int(max_bound))
if min_bound and max_bound:
if int(max_bound) < int(min_bound):
raise ValueError("max bound must be greater than or equal to min bound")
return min_bound, max_bound


def to_regex(
resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
):
Expand Down Expand Up @@ -248,7 +284,7 @@ def to_regex(
if int(max_items) < int(min_items):
raise ValueError(
"maxLength must be greater than or equal to minLength"
)
) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume)
except ValueError:
pass
return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
Expand Down Expand Up @@ -276,9 +312,29 @@ def to_regex(
return type_to_regex["string"]

elif instance_type == "number":
bounds = {"minDigitsInteger", "maxDigitsInteger", "minDigitsFraction", "maxDigitsFraction", "minDigitsExponent", "maxDigitsExponent"}
if bounds.intersection(set(instance.keys())):
min_digits_integer, max_digits_integer = validate_quantifiers(
instance.get("minDigitsInteger"), instance.get("maxDigitsInteger")
)
min_digits_fraction, max_digits_fraction = validate_quantifiers(
instance.get("minDigitsFraction"), instance.get("maxDigitsFraction")
)
min_digits_exponent, max_digits_exponent = validate_quantifiers(
instance.get("minDigitsExponent"), instance.get("maxDigitsExponent")
)
integers_quantifier = f"{{{min_digits_integer},{max_digits_integer}}}" if min_digits_integer or max_digits_integer else "*"
fraction_quantifier = f"{{{min_digits_fraction},{max_digits_fraction}}}" if min_digits_fraction or max_digits_fraction else "*"
exponent_quantifier = f"{{{min_digits_exponent},{max_digits_exponent}}}" if min_digits_exponent or max_digits_exponent else "*"
return rf"(-)?(0|[1-9][0-9]{integers_quantifier})(\.[0-9]{fraction_quantifier})?([eE][+-]?[0-9]{exponent_quantifier})?"
return type_to_regex["number"]

elif instance_type == "integer":
if "minDigits" in instance or "maxDigits" in instance:
min_digits, max_digits = validate_quantifiers(
instance.get("minDigits"), instance.get("maxDigits")
)
return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})"
return type_to_regex["integer"]

elif instance_type == "array":
Expand Down

0 comments on commit 7f05cea

Please sign in to comment.