diff --git a/docs/reference/generation/json.md b/docs/reference/generation/json.md index da9f14729..0f75a198c 100644 --- a/docs/reference/generation/json.md +++ b/docs/reference/generation/json.md @@ -42,6 +42,14 @@ print(result) generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` +!!! Note "Non-Strict Mode" + Because models may exhaust their context window before a valid schema is generated, an error resulting from an invalid generation may occur. This is particularly troublesome when an error interrupts a batch workload. To ensure `generate.json` returns a dict containing error details for invalid sequences rather than raising an error, use the following: + + ```python + generator = generate.json(model, User, strict=False) + ``` + + !!! Note "Performance" `generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once. diff --git a/outlines/generate/json.py b/outlines/generate/json.py index f75878d29..588e6fc3b 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -18,6 +18,7 @@ def json( schema_object: Union[str, object, Callable], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, + strict=True, ) -> SequenceGeneratorAdapter: """ Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. @@ -36,6 +37,10 @@ def json( whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + strict + If strict mode is enabled, generation errors which don't conform to the specified pattern + or aren't valid JSON will result in an error. Outlines guarantees generation complies with + a pattern, but patterns often allow for infinite repetition and exhaust the model_max_length. 
Returns ------- @@ -43,21 +48,39 @@ def json( transforms the result if BaseModel is used. """ + + def maybe_strict_formatter(formatter): + """If strict, use normal formatter. Otherwise, return error dict on failure""" + if strict: + return formatter + + def allow_fail_formatter(generated_output): + try: + return formatter(generated_output) + except Exception as e: + return { + "error": str(e), + "error_type": type(e).__name__, + "output": generated_output, + } + + return allow_fail_formatter + if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: schema_object.parse_raw(x) + generator.format_sequence = maybe_strict_formatter(schema_object.parse_raw) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: pyjson.loads(x) + generator.format_sequence = maybe_strict_formatter(pyjson.loads) elif isinstance(schema_object, str): schema = schema_object regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: pyjson.loads(x) + generator.format_sequence = maybe_strict_formatter(pyjson.loads) else: raise ValueError( f"Cannot parse schema {schema_object}. 
The schema must be either " diff --git a/tests/generate/test_generate_json.py b/tests/generate/test_generate_json.py new file mode 100644 index 000000000..00232ef2f --- /dev/null +++ b/tests/generate/test_generate_json.py @@ -0,0 +1,115 @@ +import json +import string + +import pytest +from pydantic import BaseModel, ValidationError + +from outlines import generate + + +class MockCharacterTokenizer: + def __init__(self): + characters = set( + string.ascii_letters + + string.digits + + string.punctuation + + string.whitespace + ) + self.vocabulary = {tok: tok_id for tok_id, tok in enumerate(characters)} + self.vocabulary["eos"] = len(characters) + self.special_tokens = {"eos"} + self.eos_token_id = len(characters) + + def convert_token_to_string(self, token): + return token + + +class MockModel: + def __init__(self, generated): + self.generated = generated + self.tokenizer = MockCharacterTokenizer() + + def generate(self, *args, **kwargs): + return self.generated + + +mock_json_schema = json.dumps( + { + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + "additionalProperties": False, + } +) + + +class MockPydanticModel(BaseModel): + message: str + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_success(schema): + model = MockModel(generated='{"message": "foo"}') + generator = generate.json(model, schema) + generator("hi") + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_success_batch(schema): + model = MockModel( + generated=[ + '{"message": "foo"}', + '{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"}', + ] + ) + generator = generate.json(model, schema) + for output in generator("hi"): + pass + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_fail(schema): + model = MockModel(generated='{"message": "foo') + generator = generate.json(model, schema) + with 
pytest.raises((json.decoder.JSONDecodeError, ValidationError)): + generator("hi") + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_fail_batch(schema): + model = MockModel( + generated=[ + '{"message": "foo"}', + '{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"', + ] + ) + generator = generate.json(model, schema) + with pytest.raises((json.decoder.JSONDecodeError, ValidationError)): + generator("hi") + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_non_strict_evade_failure(schema): + model = MockModel(generated='{"message": "foo') + generator = generate.json(model, schema, strict=False) + result = generator("hi") + assert result["error_type"] in ("JSONDecodeError", "ValidationError") + assert result["output"] == model.generated + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_non_strict_evade_failure_batch(schema): + model = MockModel( + generated=[ + '{"message": "foo"}', + '{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"', + ] + ) + generator = generate.json(model, schema, strict=False) + result = generator("hi") + if isinstance(schema, str): + assert result[0] == json.loads(model.generated[0]) + else: + assert result[0] == schema.parse_raw(model.generated[0]) + assert result[1]["error_type"] in ("JSONDecodeError", "ValidationError") + assert result[1]["output"] == model.generated[1]