From 0a156850594d0962e81d1bfcb0c44c2dd245c9db Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 30 May 2024 10:03:48 -0400 Subject: [PATCH] take into account SequenceGeneratorAdapter type --- outlines/generate/json.py | 6 ++-- tests/generate/test_integration_llamacpp.py | 32 ++++++++++----------- tests/generate/test_integration_vllm.py | 32 ++++++++++----------- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/outlines/generate/json.py b/outlines/generate/json.py index eca8d7d17..75f83f69e 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature -from outlines.generate.api import SequenceGenerator +from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -18,7 +18,7 @@ def json( schema_object: Union[str, object, Callable, Dict], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, -) -> SequenceGenerator: +) -> Union[SequenceGenerator, SequenceGeneratorAdapter]: """ Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. @@ -39,7 +39,7 @@ def json( Returns ------- - A `SequenceGenerator` instance that generates text constrained by the schema_object and + A `SequenceGenerator` or `SequenceGeneratorAdapter` instance that generates text constrained by the schema_object and transforms the result if BaseModel is used. """ diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 0203ddefc..31003e753 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -321,28 +321,28 @@ class UserPydantic(BaseModel): # Check finite state machines are the same assert ( - generator_callable.fsm.states_to_token_maps - == generator_pydantic.fsm.states_to_token_maps - == generator_dict.fsm.states_to_token_maps - == generator_str.fsm.states_to_token_maps + generator_callable.logits_processor.fsm.states_to_token_maps + == generator_pydantic.logits_processor.fsm.states_to_token_maps + == generator_dict.logits_processor.fsm.states_to_token_maps + == generator_str.logits_processor.fsm.states_to_token_maps ) assert ( - generator_callable.fsm.empty_token_ids - == generator_pydantic.fsm.empty_token_ids - == generator_dict.fsm.empty_token_ids - == generator_str.fsm.empty_token_ids + generator_callable.logits_processor.fsm.empty_token_ids + == generator_pydantic.logits_processor.fsm.empty_token_ids + == generator_dict.logits_processor.fsm.empty_token_ids + == generator_str.logits_processor.fsm.empty_token_ids ) assert ( - generator_callable.fsm.eos_token_id - == generator_pydantic.fsm.eos_token_id - == generator_dict.fsm.eos_token_id - == generator_str.fsm.eos_token_id + generator_callable.logits_processor.fsm.eos_token_id + == generator_pydantic.logits_processor.fsm.eos_token_id + == generator_dict.logits_processor.fsm.eos_token_id + == generator_str.logits_processor.fsm.eos_token_id ) assert ( - generator_callable.fsm.final_states - == generator_pydantic.fsm.final_states - == generator_dict.fsm.final_states - == generator_str.fsm.final_states + generator_callable.logits_processor.fsm.final_states + == generator_pydantic.logits_processor.fsm.final_states + == generator_dict.logits_processor.fsm.final_states + == generator_str.logits_processor.fsm.final_states ) diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py index e4d703a18..8b27166a4 100644 --- a/tests/generate/test_integration_vllm.py +++ b/tests/generate/test_integration_vllm.py @@ -309,28 +309,28 @@ class UserPydantic(BaseModel): # Check finite state machines are the same assert ( - generator_callable.fsm.states_to_token_maps - == generator_pydantic.fsm.states_to_token_maps - == generator_dict.fsm.states_to_token_maps - == generator_str.fsm.states_to_token_maps + generator_callable.logits_processor.fsm.states_to_token_maps + == generator_pydantic.logits_processor.fsm.states_to_token_maps + == generator_dict.logits_processor.fsm.states_to_token_maps + == generator_str.logits_processor.fsm.states_to_token_maps ) assert ( - generator_callable.fsm.empty_token_ids - == generator_pydantic.fsm.empty_token_ids - == generator_dict.fsm.empty_token_ids - == generator_str.fsm.empty_token_ids + generator_callable.logits_processor.fsm.empty_token_ids + == generator_pydantic.logits_processor.fsm.empty_token_ids + == generator_dict.logits_processor.fsm.empty_token_ids + == generator_str.logits_processor.fsm.empty_token_ids ) assert ( - generator_callable.fsm.eos_token_id - == generator_pydantic.fsm.eos_token_id - == generator_dict.fsm.eos_token_id - == generator_str.fsm.eos_token_id + generator_callable.logits_processor.fsm.eos_token_id + == generator_pydantic.logits_processor.fsm.eos_token_id + == generator_dict.logits_processor.fsm.eos_token_id + == generator_str.logits_processor.fsm.eos_token_id ) assert ( - generator_callable.fsm.final_states - == generator_pydantic.fsm.final_states - == generator_dict.fsm.final_states - == generator_str.fsm.final_states + generator_callable.logits_processor.fsm.final_states + == generator_pydantic.logits_processor.fsm.final_states + == generator_dict.logits_processor.fsm.final_states + == generator_str.logits_processor.fsm.final_states )