From 9c74d7c82ace6df8a24df973ee471371ee79705b Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Thu, 8 Feb 2024 17:11:31 -0600
Subject: [PATCH] Allow users to pass custom whitespace pattern for
 JSON-structured generation

---
 docs/reference/json.md                        | 12 +++-
 outlines/fsm/json_schema.py                   | 68 ++++++++++++-------
 outlines/generate/json.py                     | 36 ++++++++--
 outlines/serve/vllm.py                        | 12 ++--
 tests/fsm/test_json_schema.py                 | 29 ++++++++
 .../generate/test_integration_transfomers.py  | 24 +++++++
 6 files changed, 146 insertions(+), 35 deletions(-)

diff --git a/docs/reference/json.md b/docs/reference/json.md
index 04d238e51..84115af88 100644
--- a/docs/reference/json.md
+++ b/docs/reference/json.md
@@ -26,13 +26,21 @@ class User(BaseModel):
     id: int
 
-model = models.transformers("mistralai/Mistral-7B")
+model = models.transformers("mistralai/Mistral-7B-v0.1")
 generator = text.generate.json(model, User)
 result = generator("Create a user profile with the fields name, last_name and id")
 print(result)
 # User(name="John", last_name="Doe", id=11)
 ```
 
+!!! warning "JSON and whitespace"
+
+    By default, Outlines lets the model choose the number of line breaks and spaces used to structure the JSON. Small models tend to struggle with this freedom, in which case we recommend setting the `whitespace_pattern` parameter to the empty string:
+
+    ```python
+    generator = text.generate.json(model, User, whitespace_pattern="")
+    ```
+
 ## From a function's signature
 
 Outlines can infer the structure of the output from the signature of a function. The result is a dictionary, and can be passed directly to the function using the usual dictionary expansion syntax `**`:
@@ -44,7 +52,7 @@ from outlines import text
 def add(a: int, b: int):
     return a + b
 
-model = models.transformers("mistralai/Mistral-7B")
+model = models.transformers("mistralai/Mistral-7B-v0.1")
 generator = text.generate.json(model, add)
 result = generator("Return two integers named a and b respectively. a is odd and b even.")
diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py
index 57b255d4c..dbae20281 100644
--- a/outlines/fsm/json_schema.py
+++ b/outlines/fsm/json_schema.py
@@ -1,7 +1,7 @@
 import inspect
 import json
 import re
-from typing import Callable, Union
+from typing import Callable, Optional, Union
 
 from jsonschema.protocols import Validator
 from pydantic import BaseModel, create_model
@@ -38,7 +38,9 @@
 }
 
 
-def build_regex_from_object(object: Union[str, Callable, BaseModel]):
+def build_regex_from_object(
+    object: Union[str, Callable, BaseModel], whitespace_pattern: Optional[str] = None
+):
     """Turn a JSON schema into a regex that matches any JSON object that
     follows this schema.
 
@@ -54,6 +56,9 @@ def build_regex_from_object(object: Union[str, Callable, BaseModel]):
     ----------
     schema
         A string that represents a JSON Schema.
+    whitespace_pattern
+        Pattern to use for JSON syntactic whitespace (doesn't impact string literals).
+        Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`.
 
     Returns
     -------
@@ -83,10 +88,12 @@ def build_regex_from_object(object: Union[str, Callable, BaseModel]):
     resolver = registry.resolver()
 
     content = schema.contents
-    return to_regex(resolver, content)
+    return to_regex(resolver, content, whitespace_pattern)
 
 
-def to_regex(resolver: Resolver, instance: dict):
+def to_regex(
+    resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
+):
     """Translate a JSON Schema instance into a regex that validates the schema.
 
     Note
@@ -105,8 +112,15 @@ def to_regex(resolver: Resolver, instance: dict):
         An object that resolves references to other instances within a schema
     instance
         The instance to translate
+    whitespace_pattern
+        Pattern to use for JSON syntactic whitespace (doesn't impact string literals).
+        Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`.
 
     """
+    # Fall back to the module-level default whitespace pattern when none is given
+    if whitespace_pattern is None:
+        whitespace_pattern = WHITESPACE
+
     if "properties" in instance:
         regex = ""
         regex += r"\{"
@@ -120,12 +134,12 @@ def to_regex(resolver: Resolver, instance: dict):
         if any(is_required):
             last_required_pos = max([i for i, value in enumerate(is_required) if value])
             for i, (name, value) in enumerate(properties.items()):
-                subregex = f'{WHITESPACE}"{name}"{WHITESPACE}:{WHITESPACE}'
-                subregex += to_regex(resolver, value)
+                subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
+                subregex += to_regex(resolver, value, whitespace_pattern)
                 if i < last_required_pos:
-                    subregex = f"{subregex}{WHITESPACE},"
+                    subregex = f"{subregex}{whitespace_pattern},"
                 elif i > last_required_pos:
-                    subregex = f"{WHITESPACE},{subregex}"
+                    subregex = f"{whitespace_pattern},{subregex}"
                 regex += subregex if is_required[i] else f"({subregex})?"
         # If no property is required, we have to create a possible pattern for each property in which
        # it's the last one necessarily present. Then, we add the others as optional before and after
@@ -134,41 +148,47 @@ def to_regex(resolver: Resolver, instance: dict):
         else:
             property_subregexes = []
             for i, (name, value) in enumerate(properties.items()):
-                subregex = f'{WHITESPACE}"{name}"{WHITESPACE}:{WHITESPACE}'
-                subregex += to_regex(resolver, value)
+                subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
+                subregex += to_regex(resolver, value, whitespace_pattern)
                 property_subregexes.append(subregex)
             possible_patterns = []
             for i in range(len(property_subregexes)):
                 pattern = ""
                 for subregex in property_subregexes[:i]:
-                    pattern += f"({subregex}{WHITESPACE},)?"
+                    pattern += f"({subregex}{whitespace_pattern},)?"
                 pattern += property_subregexes[i]
                 for subregex in property_subregexes[i + 1 :]:
-                    pattern += f"({WHITESPACE},{subregex})?"
+                    pattern += f"({whitespace_pattern},{subregex})?"
                 possible_patterns.append(pattern)
             regex += f"({'|'.join(possible_patterns)})?"
 
-        regex += f"{WHITESPACE}" + r"\}"
+        regex += f"{whitespace_pattern}" + r"\}"
 
         return regex
 
     # To validate against allOf, the given data must be valid against all of the
     # given subschemas.
     elif "allOf" in instance:
-        subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
+        subregexes = [
+            to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"]
+        ]
         subregexes_str = [f"{subregex}" for subregex in subregexes]
         return rf"({''.join(subregexes_str)})"
 
     # To validate against `anyOf`, the given data must be valid against
     # any (one or more) of the given subschemas.
     elif "anyOf" in instance:
-        subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
+        subregexes = [
+            to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"]
+        ]
         return rf"({'|'.join(subregexes)})"
 
     # To validate against oneOf, the given data must be valid against exactly
     # one of the given subschemas.
elif "oneOf" in instance: - subregexes = [to_regex(resolver, t) for t in instance["oneOf"]] + subregexes = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] + ] xor_patterns = [] # json schema validation ensured there is no overlapping schemas in oneOf @@ -195,7 +215,7 @@ def to_regex(resolver: Resolver, instance: dict): elif "$ref" in instance: path = f"{instance['$ref']}" instance = resolver.lookup(path).contents - return to_regex(resolver, instance) + return to_regex(resolver, instance, whitespace_pattern) # The type keyword may either be a string or an array: # - If it's a string, it is the name of one of the basic types. @@ -254,14 +274,14 @@ def to_regex(resolver: Resolver, instance: dict): num_repeats = rf"{{{max(min_items - 1, 0)},}}" else: if max_items < 1: - return rf"\[{WHITESPACE}\]" + return rf"\[{whitespace_pattern}\]" num_repeats = rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" allow_empty = "?" if min_items == 0 else "" if "items" in instance: - items_regex = to_regex(resolver, instance["items"]) - return rf"\[{WHITESPACE}(({items_regex})(,{WHITESPACE}({items_regex})){num_repeats}){allow_empty}{WHITESPACE}\]" + items_regex = to_regex(resolver, instance["items"], whitespace_pattern) + return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" else: # Here we need to make the choice to exclude generating list of objects # if the specification of the object is not given, even though a JSON @@ -273,8 +293,8 @@ def to_regex(resolver: Resolver, instance: dict): {"type": "integer"}, {"type": "string"}, ] - regexes = [to_regex(resolver, t) for t in types] - return rf"\[{WHITESPACE}({'|'.join(regexes)})(,{WHITESPACE}({'|'.join(regexes)})){num_repeats}){allow_empty}{WHITESPACE}\]" + regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] + return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}){allow_empty}{whitespace_pattern}\]" elif instance_type == "boolean": return type_to_regex["boolean"] @@ -287,7 +307,9 @@ def to_regex(resolver: Resolver, instance: dict): # if the specification of the object is not give, even though a JSON # object that contains an object here would be valid under the specification. regexes = [ - to_regex(resolver, {"type": t}) for t in instance_type if t != "object" + to_regex(resolver, {"type": t}, whitespace_pattern) + for t in instance_type + if t != "object" ] return rf"({'|'.join(regexes)})" diff --git a/outlines/generate/json.py b/outlines/generate/json.py index 9dbf817cb..b81c438a3 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -1,6 +1,6 @@ import json as pyjson from functools import singledispatch -from typing import Callable, Union +from typing import Callable, Optional, Union from pydantic import BaseModel @@ -14,21 +14,49 @@ @singledispatch def json( - model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial() + model, + schema_object: Union[str, object, Callable], + sampler: Sampler = multinomial(), + whitespace_pattern: Optional[str] = None, ) -> SequenceGenerator: + """ + Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + schema_object: + The JSON Schema to generate data for. Can be a JSON string, a Pydantic model, or a callable + that returns a JSON schema. 
+    sampler:
+        The sampling algorithm to use to generate token ids from the logits
+        distribution.
+    whitespace_pattern:
+        Pattern to use for JSON syntactic whitespace (doesn't impact string literals).
+        Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`.
+
+    Returns
+    -------
+    A `SequenceGenerator` instance that generates text constrained by the
+    schema_object and transforms the result when a Pydantic `BaseModel` is used.
+
+    """
     if isinstance(schema_object, type(BaseModel)):
         schema = pyjson.dumps(schema_object.model_json_schema())
-        regex_str = build_regex_from_object(schema)
+        regex_str = build_regex_from_object(schema, whitespace_pattern)
         generator = regex(model, regex_str, sampler)
         generator.format_sequence = lambda x: schema_object.parse_raw(x)
     elif callable(schema_object):
         schema = pyjson.dumps(get_schema_from_signature(schema_object))
-        regex_str = build_regex_from_object(schema)
+        regex_str = build_regex_from_object(schema, whitespace_pattern)
         generator = regex(model, regex_str, sampler)
         generator.format_sequence = lambda x: pyjson.loads(x)
     elif isinstance(schema_object, str):
         schema = schema_object
-        regex_str = build_regex_from_object(schema)
+        regex_str = build_regex_from_object(schema, whitespace_pattern)
         generator = regex(model, regex_str, sampler)
         generator.format_sequence = lambda x: pyjson.loads(x)
     else:
diff --git a/outlines/serve/vllm.py b/outlines/serve/vllm.py
index 50fa55b62..48e9c87c7 100644
--- a/outlines/serve/vllm.py
+++ b/outlines/serve/vllm.py
@@ -2,7 +2,7 @@
 import json
 import math
 from collections import defaultdict
-from typing import DefaultDict, List
+from typing import DefaultDict, List, Optional
 
 import torch
 
@@ -105,8 +105,8 @@ def convert_token_to_string(token: str) -> str:
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
-    def __init__(self, schema, llm):
-        """Compile the FSM that drives the JSON-structured generation.
+    def __init__(self, schema, llm, whitespace_pattern: Optional[str] = None):
+        """Compile the FSM that drives the JSON-guided generation.
         Parameters
         ----------
@@ -114,9 +114,11 @@ def __init__(self, schema, llm):
             A JSON schema that encodes the structure we want the model to generate
         llm
             An instance of `vllm.LLM`
-
+        whitespace_pattern
+            Pattern to use for JSON syntactic whitespace (doesn't impact string literals).
+            Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`.
         """
         if isinstance(schema, dict):
             schema = json.dumps(schema)
-        regex_string = build_regex_from_object(schema)
+        regex_string = build_regex_from_object(schema, whitespace_pattern)
         super().__init__(regex_string, llm)
diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py
index d38a7b6a5..833b3e884 100644
--- a/tests/fsm/test_json_schema.py
+++ b/tests/fsm/test_json_schema.py
@@ -540,3 +540,32 @@ def test_format(schema, regex, examples):
             assert match.span() == (0, len(string))
         else:
             assert match is None
+
+
+@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"])
+def test_json_schema_custom_whitespace_pattern(whitespace_pattern):
+    """Assert that the `whitespace_pattern` setting is respected."""
+
+    class MockModel(BaseModel):
+        foo: int
+        bar: str
+
+    # assert that any whitespace pattern can be used, even a non-whitespace one
+    if whitespace_pattern == "abc":
+        build_regex_from_object(MockModel, whitespace_pattern)
+        return
+
+    pattern = build_regex_from_object(MockModel, whitespace_pattern)
+
+    mock_result_mult_ws = (
+        """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}"""
+    )
+    mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}"""
+
+    match_default_ws = re.fullmatch(pattern, mock_result_mult_ws)
+    if whitespace_pattern is None:
+        assert match_default_ws
+    else:
+        assert match_default_ws is None
+
+    assert re.fullmatch(pattern, mock_result_maybe_ws)
diff --git a/tests/generate/test_integration_transfomers.py b/tests/generate/test_integration_transfomers.py
index 584c73717..35019b238 100644
--- a/tests/generate/test_integration_transfomers.py
+++ b/tests/generate/test_integration_transfomers.py
@@ -543,6 +543,30 @@ def test_transformers_logits_vocab_size():
     assert sequence == "False"
 
 
+def test_transformers_json_custom_ws():
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompt = "Output some JSON with newlines"  # try to force the model to use newlines
+
+    schema = """{
+      "title": "spam",
+      "type": "object",
+      "properties": {
+           "foo" : {"type": "integer"},
+           "bar": {"type": "integer"}
+        },
+      "required": ["foo", "bar"]
+    }
+    """
+
+    rng = torch.Generator()
+    rng.manual_seed(0)
+
+    generator = generate.json(model, schema, whitespace_pattern=r"[ ]?")
+    generator.format_sequence = lambda x: x  # patch the generator to return raw text
+    assert "\n" not in generator(prompt, max_tokens=500, rng=rng)
+
+
 def test_transformers_reduced_vocabulary_caching():
     tokenizer = TransformerTokenizer("gpt2")
     tokenizer2 = TransformerTokenizer("gpt2")
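
For readers who want to try the feature, here is a minimal sketch of the new `whitespace_pattern` argument used end to end with the `generate.json` API this patch modifies. The model name, schema, and prompt are illustrative only:

```python
# Minimal sketch of the whitespace_pattern knob added in this patch.
# Assumes an outlines revision containing this change; the model name
# and prompt are illustrative.
from pydantic import BaseModel

from outlines import generate, models


class User(BaseModel):
    name: str
    last_name: str
    id: int


model = models.transformers("mistralai/Mistral-7B-v0.1")

# Forbid all syntactic whitespace between JSON tokens: the most compact,
# most deterministic output, recommended for small models.
compact = generate.json(model, User, whitespace_pattern="")

# Or allow at most one space or newline between JSON tokens.
relaxed = generate.json(model, User, whitespace_pattern=r"[\n ]?")

print(compact("Create a user profile with the fields name, last_name and id"))
# e.g. User(name='John', last_name='Doe', id=11)
```

Note that the pattern constrains only the whitespace *between* JSON tokens; string values may still contain spaces and newlines, as the docstrings above point out.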
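The vLLM logits processor accepts the same argument. A hedged sketch follows; it requires `vllm`, the model name and prompt are illustrative, and passing `SamplingParams(logits_processors=...)` assumes a vLLM version with per-request logits-processor support:

```python
# Hypothetical usage of JSONLogitsProcessor with a custom whitespace pattern,
# mirroring the docstring above, which says `llm` is an instance of `vllm.LLM`.
import json

from vllm import LLM, SamplingParams

from outlines.serve.vllm import JSONLogitsProcessor

schema = {
    "type": "object",
    "properties": {"foo": {"type": "integer"}, "bar": {"type": "string"}},
    "required": ["foo", "bar"],
}

llm = LLM("mistralai/Mistral-7B-v0.1")
# Allow at most a single space between JSON tokens, as in the integration test above.
logits_processor = JSONLogitsProcessor(schema, llm, whitespace_pattern=r"[ ]?")

params = SamplingParams(max_tokens=200, logits_processors=[logits_processor])
output = llm.generate(["Give me a foo and a bar."], params)[0]
print(json.loads(output.outputs[0].text))
```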