Skip to content

Commit

Permalink
take into account SequenceGeneratorAdapter type
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok committed May 30, 2024
1 parent f7580f6 commit 0a15685
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 35 deletions.
6 changes: 3 additions & 3 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel

from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from outlines.generate.api import SequenceGenerator
from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial

Expand All @@ -18,7 +18,7 @@ def json(
schema_object: Union[str, object, Callable, Dict],
sampler: Sampler = multinomial(),
whitespace_pattern: Optional[str] = None,
) -> SequenceGenerator:
) -> Union[SequenceGenerator, SequenceGeneratorAdapter]:
"""
Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
Expand All @@ -39,7 +39,7 @@ def json(
Returns
-------
A `SequenceGenerator` instance that generates text constrained by the schema_object and
A `SequenceGenerator` or `SequenceGeneratorAdapter` instance that generates text constrained by the schema_object and
transforms the result if BaseModel is used.
"""
Expand Down
32 changes: 16 additions & 16 deletions tests/generate/test_integration_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,28 +321,28 @@ class UserPydantic(BaseModel):

# Check finite state machines are the same
assert (
generator_callable.fsm.states_to_token_maps
== generator_pydantic.fsm.states_to_token_maps
== generator_dict.fsm.states_to_token_maps
== generator_str.fsm.states_to_token_maps
generator_callable.logits_processor.fsm.states_to_token_maps
== generator_pydantic.logits_processor.fsm.states_to_token_maps
== generator_dict.logits_processor.fsm.states_to_token_maps
== generator_str.logits_processor.fsm.states_to_token_maps
)
assert (
generator_callable.fsm.empty_token_ids
== generator_pydantic.fsm.empty_token_ids
== generator_dict.fsm.empty_token_ids
== generator_str.fsm.empty_token_ids
generator_callable.logits_processor.fsm.empty_token_ids
== generator_pydantic.logits_processor.fsm.empty_token_ids
== generator_dict.logits_processor.fsm.empty_token_ids
== generator_str.logits_processor.fsm.empty_token_ids
)
assert (
generator_callable.fsm.eos_token_id
== generator_pydantic.fsm.eos_token_id
== generator_dict.fsm.eos_token_id
== generator_str.fsm.eos_token_id
generator_callable.logits_processor.fsm.eos_token_id
== generator_pydantic.logits_processor.fsm.eos_token_id
== generator_dict.logits_processor.fsm.eos_token_id
== generator_str.logits_processor.fsm.eos_token_id
)
assert (
generator_callable.fsm.final_states
== generator_pydantic.fsm.final_states
== generator_dict.fsm.final_states
== generator_str.fsm.final_states
generator_callable.logits_processor.fsm.final_states
== generator_pydantic.logits_processor.fsm.final_states
== generator_dict.logits_processor.fsm.final_states
== generator_str.logits_processor.fsm.final_states
)


Expand Down
32 changes: 16 additions & 16 deletions tests/generate/test_integration_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,28 +309,28 @@ class UserPydantic(BaseModel):

# Check finite state machines are the same
assert (
generator_callable.fsm.states_to_token_maps
== generator_pydantic.fsm.states_to_token_maps
== generator_dict.fsm.states_to_token_maps
== generator_str.fsm.states_to_token_maps
generator_callable.logits_processor.fsm.states_to_token_maps
== generator_pydantic.logits_processor.fsm.states_to_token_maps
== generator_dict.logits_processor.fsm.states_to_token_maps
== generator_str.logits_processor.fsm.states_to_token_maps
)
assert (
generator_callable.fsm.empty_token_ids
== generator_pydantic.fsm.empty_token_ids
== generator_dict.fsm.empty_token_ids
== generator_str.fsm.empty_token_ids
generator_callable.logits_processor.fsm.empty_token_ids
== generator_pydantic.logits_processor.fsm.empty_token_ids
== generator_dict.logits_processor.fsm.empty_token_ids
== generator_str.logits_processor.fsm.empty_token_ids
)
assert (
generator_callable.fsm.eos_token_id
== generator_pydantic.fsm.eos_token_id
== generator_dict.fsm.eos_token_id
== generator_str.fsm.eos_token_id
generator_callable.logits_processor.fsm.eos_token_id
== generator_pydantic.logits_processor.fsm.eos_token_id
== generator_dict.logits_processor.fsm.eos_token_id
== generator_str.logits_processor.fsm.eos_token_id
)
assert (
generator_callable.fsm.final_states
== generator_pydantic.fsm.final_states
== generator_dict.fsm.final_states
== generator_str.fsm.final_states
generator_callable.logits_processor.fsm.final_states
== generator_pydantic.logits_processor.fsm.final_states
== generator_dict.logits_processor.fsm.final_states
== generator_str.logits_processor.fsm.final_states
)


Expand Down

0 comments on commit 0a15685

Please sign in to comment.