From e1c96041c6ba372df560edcbbea37bde876cd8c1 Mon Sep 17 00:00:00 2001 From: bparis Date: Fri, 1 Dec 2023 10:30:28 +0100 Subject: [PATCH] Fix oneOf implementation of the json schema spec Implements XOR regex using negative lookaheads. --- outlines/text/json_schema.py | 11 ++++++++++- tests/text/test_json_schema.py | 15 ++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py index 4044d225a..891f3d9c9 100644 --- a/outlines/text/json_schema.py +++ b/outlines/text/json_schema.py @@ -137,7 +137,16 @@ def to_regex(resolver: Resolver, instance: dict): # one of the given subschemas. elif "oneOf" in instance: subregexes = [to_regex(resolver, t) for t in instance["oneOf"]] - return rf"({'|'.join(subregexes)})" + + xor_patterns = [] + # json schema validation ensured there is no overlapping schemas in oneOf + for subregex in subregexes: + other_subregexes = filter(lambda r: r != subregex, subregexes) + other_subregexes_str = "|".join([f"{s}" for s in other_subregexes]) + negative_lookahead = f"(?!.*({other_subregexes_str}))" + xor_patterns.append(f"({subregex}){negative_lookahead}") + + return rf"({'|'.join(xor_patterns)})" # The enum keyword is used to restrict a value to a fixed set of values. It # must be an array with at least one element, where each element is unique. diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py index 380392980..a0af780ef 100644 --- a/tests/text/test_json_schema.py +++ b/tests/text/test_json_schema.py @@ -220,10 +220,19 @@ def test_match_number(pattern, does_match): ( { "title": "Foo", - "oneOf": [{"type": "string"}, {"type": "number"}], + "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], }, - rf"({STRING}|{NUMBER})", - [("12.3", True), ('"a"', True), ('1.3"a"', False)], + rf"(({STRING})(?!.*({NUMBER}|{BOOLEAN}))|({NUMBER})(?!.*({STRING}|{BOOLEAN}))|({BOOLEAN})(?!.*({STRING}|{NUMBER})))", + [ + ("12.3", True), + ("true", True), + ('"a"', True), + ("null", False), + ("", False), + ("12true", False), + ('1.3"a"', False), + ('12.3true"a"', False), + ], ), # anyOf (