diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index c57cea7cd..aa8e9f79b 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -2,7 +2,8 @@ import json import re import warnings -from typing import Callable, Optional +from copy import deepcopy +from typing import Callable, List, Optional from jsonschema.protocols import Validator from pydantic import create_model @@ -39,7 +40,11 @@ } -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): +def build_regex_from_schema( + schema: str, + whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, +): """Turn a JSON schema into a regex that matches any JSON object that follows this schema. @@ -58,6 +63,12 @@ def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = Non whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. 
def _is_null_type(instance: dict) -> bool:
    """Tell whether a sub-schema only admits the JSON ``null`` value.

    Parameters
    ----------
    instance
        A parsed JSON (sub-)schema. Anything that is not a dict (for example
        a boolean schema) is never considered a null schema.

    Returns
    -------
    bool
        True if ``instance`` declares ``"type": "null"`` or ``"const": null``.
        A ``"type"`` entry whose value is ``None`` is tolerated as well.

    """
    if not isinstance(instance, dict):
        # Boolean schemas (`true` / `false`) carry no "type" or "const" entry.
        return False
    if "type" in instance and (instance["type"] == "null" or instance["type"] is None):
        return True
    # JSON `null` is parsed to `None`. The *string* "null" is an ordinary
    # string constant and must not be mistaken for the null value.
    if "const" in instance and instance["const"] is None:
        return True
    return False


def _has_null_type(instance_list: List[dict]) -> bool:
    """Return True if at least one sub-schema in ``instance_list`` is a null schema."""
    return any(_is_null_type(instance) for instance in instance_list)


def _optimize_schema_in_place(schema):
    """Apply the optimizations of `optimize_schema` directly on ``schema``.

    The caller owns ``schema`` (it is mutated). The same object is returned
    for convenience. Values that are not dicts are returned untouched.
    """
    if not isinstance(schema, dict):
        return schema

    definitions = schema.get("$defs")
    if isinstance(definitions, dict):
        for definition in definitions.values():
            _optimize_schema_in_place(definition)

    properties = schema.get("properties")
    if not isinstance(properties, dict):
        return schema

    new_optional_keys = set()
    keys_to_remove = set()
    for key, subschema in properties.items():
        # Nested objects are optimized in place, so the result is never lost.
        _optimize_schema_in_place(subschema)
        if not isinstance(subschema, dict):
            continue
        if "type" in subschema:
            subschema_type = subschema["type"]
            if subschema_type == "null":
                # A field that can only be null carries no information.
                keys_to_remove.add(key)
            elif subschema_type == "array" and subschema.get("minItems", 0) == 0:
                # An array that may be empty can simply be left out.
                new_optional_keys.add(key)
        elif "anyOf" in subschema and _has_null_type(subschema["anyOf"]):
            alternatives = [
                alternative
                for alternative in subschema.pop("anyOf")
                if not _is_null_type(alternative)
            ]
            if not alternatives:
                keys_to_remove.add(key)
                continue
            if len(alternatives) == 1 and isinstance(alternatives[0], dict):
                # A single remaining alternative is merged into the field itself.
                subschema.update(alternatives[0])
            else:
                subschema["anyOf"] = alternatives
            new_optional_keys.add(key)

    if "required" in schema:
        schema["required"] = [
            key
            for key in schema["required"]
            if key not in new_optional_keys and key not in keys_to_remove
        ]
    for key in keys_to_remove:
        del properties[key]
    return schema


def optimize_schema(instance):
    """Relax a JSON schema so that nullable fields do not have to be generated.

    The following rewrites are applied to every object found under
    ``"$defs"`` and, recursively, under ``"properties"``:

    - a property whose schema is ``{"type": "null"}``, or whose ``anyOf`` only
      contains null schemas, is removed from ``"properties"`` and ``"required"``;
    - a property whose ``anyOf`` contains a null schema loses that alternative
      and is removed from ``"required"``; if a single alternative remains it
      is merged into the property schema;
    - a property of type ``"array"`` with no ``minItems`` (or ``minItems`` of
      0) is removed from ``"required"``.

    Parameters
    ----------
    instance
        The parsed JSON schema. It is not modified.

    Returns
    -------
    A deep copy of ``instance`` with the rewrites above applied.

    """
    # Copy once at the entry point; the recursion then works in place.
    return _optimize_schema_in_place(deepcopy(instance))
Returns ------- @@ -45,17 +52,23 @@ def json( """ if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema( + schema, whitespace_pattern, enable_schema_optimization + ) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: schema_object.parse_raw(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema( + schema, whitespace_pattern, enable_schema_optimization + ) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) elif isinstance(schema_object, str): schema = schema_object - regex_str = build_regex_from_schema(schema, whitespace_pattern) + regex_str = build_regex_from_schema( + schema, whitespace_pattern, enable_schema_optimization + ) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) else: diff --git a/outlines/integrations/llamacpp.py b/outlines/integrations/llamacpp.py index 4041c54fb..763dd3a1e 100644 --- a/outlines/integrations/llamacpp.py +++ b/outlines/integrations/llamacpp.py @@ -171,6 +171,7 @@ def __init__( schema: Union[dict, Type[BaseModel], str], llm: "Llama", whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ): """Compile the FSM that drives the JSON-guided generation. @@ -184,9 +185,17 @@ def __init__( Pattern to use for JSON syntactic whitespace (doesn't impact string literals). For example, to allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. 
Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. """ schema_str = convert_json_schema_to_str(json_schema=schema) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + regex_string = build_regex_from_schema( + schema_str, whitespace_pattern, enable_schema_optimization + ) super().__init__(regex_string=regex_string, llm=llm) diff --git a/outlines/integrations/transformers.py b/outlines/integrations/transformers.py index 7c1bafd22..c01d3c86c 100644 --- a/outlines/integrations/transformers.py +++ b/outlines/integrations/transformers.py @@ -140,6 +140,7 @@ def __init__( schema: Union[dict, Type[BaseModel], str], tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline], whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ): """Compile the FSM that drives the JSON-guided generation. @@ -153,7 +154,15 @@ def __init__( Pattern to use for JSON syntactic whitespace (doesn't impact string literals). For example, to allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. 
""" schema_str = convert_json_schema_to_str(json_schema=schema) - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + regex_string = build_regex_from_schema( + schema_str, whitespace_pattern, enable_schema_optimization + ) super().__init__(regex_string=regex_string, tokenizer_or_pipe=tokenizer_or_pipe) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index 6ed56d71b..69cb2a8f8 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -132,6 +132,7 @@ def __init__( schema: Union[dict, Type[BaseModel], str], llm: "LLM", whitespace_pattern: Optional[str] = None, + enable_schema_optimization: bool = False, ): """Compile the FSM that drives the JSON-guided generation. @@ -145,7 +146,15 @@ def __init__( Pattern to use for JSON syntactic whitespace (doesn't impact string literals). For example, to allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + enable_schema_optimization: + If True, this will speed up generation by not requiring optional keys to be + present in the output. This is especially useful for large schemas with many + optional keys. Note though that this further restricts the support + distribution. Thus, it is necessary to remove the optional keys from the + finetuning dataset as well if needed. Hence, we set this to False by default. 
@pytest.mark.parametrize(
    ("schema", "expected_schema"),
    [
        # Case 1: nothing can be optimized, the schema comes back unchanged.
        (
            {
                "properties": {"field_a": {"title": "Field A", "type": "integer"}},
                "required": ["field_a"],
                "title": "Test",
                "type": "object",
            },
            {
                "properties": {"field_a": {"title": "Field A", "type": "integer"}},
                "required": ["field_a"],
                "title": "Test",
                "type": "object",
            },
        ),
        # Case 2: a field with a null alternative in `anyOf` becomes optional,
        # and a field that can only be null is dropped.
        (
            {
                "properties": {
                    "field_a": {"title": "Field A", "type": "integer"},
                    "field_b": {
                        "anyOf": [{"type": "integer"}, {"type": "null"}],
                        "title": "Field B",
                    },
                    "field_c": {"title": "Field C", "type": "null"},
                },
                "required": ["field_a", "field_b", "field_c"],
                "title": "Test",
                "type": "object",
            },
            {
                "properties": {
                    "field_a": {"title": "Field A", "type": "integer"},
                    "field_b": {"title": "Field B", "type": "integer"},
                },
                "required": ["field_a"],
                "title": "Test",
                "type": "object",
            },
        ),
        # Case 3: the same rules applied through `$defs` (multilevel schema).
        (
            {
                "$defs": {
                    "TestCell": {
                        "properties": {
                            "g": {"title": "G", "type": "integer"},
                            "h": {
                                "anyOf": [{"type": "string"}, {"type": "null"}],
                                "title": "H",
                            },
                        },
                        "required": ["g", "h"],
                        "title": "TestCell",
                        "type": "object",
                    },
                    "TestLineItem": {
                        "properties": {
                            "a": {"title": "A", "type": "integer"},
                            "b": {
                                "anyOf": [{"type": "integer"}, {"type": "null"}],
                                "title": "B",
                            },
                            "c": {"title": "C", "type": "string"},
                            "d": {
                                "anyOf": [{"type": "string"}, {"type": "null"}],
                                "title": "D",
                            },
                            "e": {"title": "E", "type": "null"},
                            "f": {
                                "anyOf": [{"type": "integer"}, {"type": "string"}],
                                "title": "F",
                            },
                            "i": {"$ref": "#/$defs/TestCell"},
                        },
                        "required": ["a", "b", "c", "d", "e", "f", "i"],
                        "title": "TestLineItem",
                        "type": "object",
                    },
                },
                "properties": {
                    "line_items": {
                        "anyOf": [
                            {
                                "items": {"$ref": "#/$defs/TestLineItem"},
                                "type": "array",
                            },
                            {"type": "null"},
                        ],
                        "title": "Line Items",
                    }
                },
                "required": ["line_items"],
                "title": "TestTable",
                "type": "object",
            },
            {
                "$defs": {
                    "TestCell": {
                        "properties": {
                            "g": {"title": "G", "type": "integer"},
                            "h": {"title": "H", "type": "string"},
                        },
                        "required": ["g"],
                        "title": "TestCell",
                        "type": "object",
                    },
                    "TestLineItem": {
                        "properties": {
                            "a": {"title": "A", "type": "integer"},
                            "b": {"title": "B", "type": "integer"},
                            "c": {"title": "C", "type": "string"},
                            "d": {"title": "D", "type": "string"},
                            "f": {
                                "anyOf": [{"type": "integer"}, {"type": "string"}],
                                "title": "F",
                            },
                            "i": {"$ref": "#/$defs/TestCell"},
                        },
                        "required": ["a", "c", "f", "i"],
                        "title": "TestLineItem",
                        "type": "object",
                    },
                },
                "properties": {
                    "line_items": {
                        "title": "Line Items",
                        "items": {"$ref": "#/$defs/TestLineItem"},
                        "type": "array",
                    },
                },
                "required": [],
                "title": "TestTable",
                "type": "object",
            },
        ),
        # Case 4: schema shaped like one generated from a function signature.
        (
            {
                "properties": {
                    "a": {"title": "A", "type": "integer"},
                    "b": {
                        "anyOf": [{"type": "integer"}, {"type": "null"}],
                        "title": "B",
                    },
                },
                "required": ["a", "b"],
                "title": "Arguments",
                "type": "object",
            },
            {
                "properties": {
                    "a": {"title": "A", "type": "integer"},
                    "b": {"title": "B", "type": "integer"},
                },
                "required": ["a"],
                "title": "Arguments",
                "type": "object",
            },
        ),
    ],
)
def test_json_schema_optimization(schema: dict, expected_schema: dict):
    """Check that `optimize_schema` rewrites each schema into its expected form."""
    assert optimize_schema(schema) == expected_schema